Added scalar lists APIs for addcdiv and addcmul (#45932)
Summary:
1) Added new APIs:
 _foreach_addcdiv(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars)
 _foreach_addcdiv_(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars)
 _foreach_addcmul(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars)
 _foreach_addcmul_(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars)

2) Updated the optimizers to use the new APIs

Tested via unit tests
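For illustration only (not part of this commit), the new scalar-list overloads can be exercised from Python roughly as follows; tensor shapes, values, and list lengths here are arbitrary:

import torch

params  = [torch.randn(5) for _ in range(3)]
grads   = [torch.randn(5) for _ in range(3)]
denoms  = [torch.rand(5) + 1e-3 for _ in range(3)]
scalars = [0.1, 0.5, 1.0]  # one scalar per tensor in the lists

# In-place: params[i] += scalars[i] * grads[i] / denoms[i]
torch._foreach_addcdiv_(params, grads, denoms, scalars)

# Out-of-place variant returns a new list of tensors
out = torch._foreach_addcmul(params, grads, denoms, scalars)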

Pull Request resolved: #45932

Reviewed By: navahgar

Differential Revision: D24150306

Pulled By: izdeby

fbshipit-source-id: c2e65dedc95d9d81a2fdd116e41df0accb0b6f26
Iurii Zdebskyi authored and facebook-github-bot committed Oct 14, 2020
1 parent f2e5ae4 commit 8a074af
Showing 11 changed files with 485 additions and 130 deletions.
118 changes: 68 additions & 50 deletions aten/src/ATen/native/ForeachOpsKernels.cpp

Large diffs are not rendered by default.

75 changes: 68 additions & 7 deletions aten/src/ATen/native/ForeachUtils.h
@@ -5,6 +5,19 @@ namespace at {
namespace native {
namespace {

void check_nonempty_and_same_length(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length, got ", tensors1.size(), " and ", tensors2.size());
TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must be of the same length, got ", tensors1.size(), " and ", tensors3.size());
}

void check_nonempty_and_same_length(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length, got ", tensors1.size(), " and ", tensors2.size());
TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must be of the same length, got ", tensors1.size(), " and ", tensors3.size());
TORCH_CHECK(tensors1.size() == scalars.size(), "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size());
}
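As a hedged sketch of what these helpers guard against (not taken from the diff), passing mismatched list lengths through the Python-facing ops should fail with the length-check message above:

import torch

xs = [torch.randn(4) for _ in range(3)]
ys = [torch.randn(4) for _ in range(3)]
zs = [torch.rand(4) + 1 for _ in range(3)]

# Three tensors per list but only two scalars: rejected by the length check.
try:
    torch._foreach_addcmul_(xs, ys, zs, [1.0, 2.0])
except RuntimeError as err:
    print(err)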

// Set of foreach API restrictions
// - All tensors must be of the same dtype
// - All corresponding tensors must be of the same size
@@ -83,6 +96,27 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
return true;
}

bool can_use_fast_route(TensorList tensors) {
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
auto expected_device = tensors[0].device();

for (auto t : tensors) {
if (t.layout() != at::kStrided) {
return false;
}

if (!t.is_non_overlapping_and_dense()) {
return false;
}

if (t.device() != expected_device) {
return false;
}
}

return true;
}
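For context, a minimal sketch of an input that is strided but not non-overlapping-and-dense, assuming the usual foreach dispatch where failing these checks falls back to a slower per-tensor path rather than erroring:

import torch

xs = [torch.randn(8)[::2] for _ in range(2)]  # views with gaps: not non-overlapping-and-dense
ys = [torch.randn(4) for _ in range(2)]
zs = [torch.rand(4) + 1 for _ in range(2)]

# Still valid; the implementation is expected to take the slow (per-tensor) path here,
# producing the same result as the fused fast path.
torch._foreach_addcdiv_(xs, ys, zs, [1.0, 2.0])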

bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
auto expected_device = tensors1[0].device();

@@ -117,27 +151,54 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
return true;
}

bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
auto expected_device = tensors1[0].device();

for (int64_t i = 0; i < tensors1.size(); i++) {
TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(),
"Corresponding tensors from tensor lists have different sizes, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());

TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(),
"Corresponding tensors from tensor lists have different sizes, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes());

if (tensors1[i].device() != expected_device ||
tensors2[i].device() != expected_device ||
tensors3[i].device() != expected_device) {
return false;
}

if (tensors1[i].layout() != at::kStrided ||
tensors2[i].layout() != at::kStrided ||
tensors3[i].layout() != at::kStrided) {
return false;
}

if (tensors1[i].strides() != tensors2[i].strides() ||
tensors1[i].strides() != tensors3[i].strides()) {
return false;
}

if (!tensors1[i].is_non_overlapping_and_dense() ||
!tensors2[i].is_non_overlapping_and_dense() ||
!tensors3[i].is_non_overlapping_and_dense()) {
return false;
}
}

return true;
}

bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
TORCH_CHECK(tensors1.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
return can_use_fast_route(tensors1, tensors2, tensors3);
}

bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
156 changes: 156 additions & 0 deletions aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -724,6 +724,162 @@ struct BinaryOpListFunctor {
}
};

template<typename T>
struct PointwiseOpScalarListFunctor_ {
using opmath_t = typename get_opmath_t<T>::opmath_t;
template<typename Op> __device__ __forceinline__ void operator() (
int chunk_size,
TensorListScalarListMetadata<opmath_t, 3>& tl,
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* x = (T*)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;

T* y = (T*)tl.addresses[1][tensor_loc];
y += chunk_idx * chunk_size;

T* z = (T*)tl.addresses[2][tensor_loc];
z += chunk_idx * chunk_size;

opmath_t scalar = tl.scalar_vals[tensor_loc];

n -= chunk_idx * chunk_size;

T r_x[kILP];
T r_y[kILP];
T r_z[kILP];

// to make things simple, we put aligned case in a different code path
if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(z)) {
for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
// load
load_store(r_x, x, 0 , i_start);
load_store(r_y, y, 0 , i_start);
load_store(r_z, z, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = static_cast<T>(static_cast<opmath_t>(r_x[ii]) +
scalar * op(static_cast<opmath_t>(r_y[ii]),
static_cast<opmath_t>(r_z[ii])));
}
// store
load_store(x, r_x, i_start, 0);
}
}
else {
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = 0;
r_y[ii] = 0;
r_z[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size) {
r_x[ii] = x[i];
r_y[ii] = y[i];
r_z[ii] = z[i];
}
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = static_cast<T>(static_cast<opmath_t>(r_x[ii]) +
scalar * op(static_cast<opmath_t>(r_y[ii]),
static_cast<opmath_t>(r_z[ii])));
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size)
x[i] = r_x[ii];
}
}
}
}
};
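Per element, the in-place functor above computes x = x + scalar * op(y, z) and writes the result back into x. A rough single-tensor reference in Python, assuming op is division (the addcdiv case); the helper name is illustrative:

import torch

def reference_addcdiv_(x, y, z, scalar):
    # mirrors r_x[ii] = r_x[ii] + scalar * op(r_y[ii], r_z[ii]) with op = div
    x.addcdiv_(y, z, value=scalar)
    return x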

template<typename T>
struct PointwiseOpScalarListFunctor {
using opmath_t = typename get_opmath_t<T>::opmath_t;
template<typename Op> __device__ __forceinline__ void operator() (
int chunk_size,
TensorListScalarListMetadata<opmath_t, 4>& tl,
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* x = (T*)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;

T* y = (T*)tl.addresses[1][tensor_loc];
y += chunk_idx * chunk_size;

T* z = (T*)tl.addresses[2][tensor_loc];
z += chunk_idx * chunk_size;

T* out = (T*)tl.addresses[3][tensor_loc];
out += chunk_idx * chunk_size;

opmath_t scalar = tl.scalar_vals[tensor_loc];

n -= chunk_idx * chunk_size;

T r_x[kILP];
T r_y[kILP];
T r_z[kILP];

// to make things simple, we put aligned case in a different code path
if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(z) && is_aligned(out)) {
for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
// load
load_store(r_x, x, 0 , i_start);
load_store(r_y, y, 0 , i_start);
load_store(r_z, z, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = static_cast<T>(static_cast<opmath_t>(r_x[ii]) +
scalar * op(static_cast<opmath_t>(r_y[ii]),
static_cast<opmath_t>(r_z[ii])));
}
// store
load_store(out, r_x, i_start, 0);
}
}
else {
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = 0;
r_y[ii] = 0;
r_z[ii] = 0;

int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size) {
r_x[ii] = x[i];
r_y[ii] = y[i];
r_z[ii] = z[i];
}
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = static_cast<T>(static_cast<opmath_t>(r_x[ii]) +
scalar * op(static_cast<opmath_t>(r_y[ii]),
static_cast<opmath_t>(r_z[ii])));
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size)
out[i] = r_x[ii];
}
}
}
}
};
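The out-of-place functor differs only in that results are written to the fourth tensor list (out) instead of back into x. A rough reference, again assuming op is division and with an illustrative helper name:

import torch

def reference_addcdiv(x, y, z, scalar):
    # out = x + scalar * (y / z); x is left untouched
    return torch.addcdiv(x, y, z, value=scalar)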

} // namespace

}} // namespace at::native
