Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Reduce usages of TensorUtils<T>::DataType in THC. #8056

Merged
merged 1 commit into from
Jun 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 17 additions & 3 deletions aten/src/THC/THCApply.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,14 @@ inline bool getApplyGrid(THCState* state, uint64_t totalElements, dim3& grid) {
return true;
}

template <typename TensorTypeA,
template <typename ScalarTypeA,
typename TensorTypeA,
typename Op>
bool THC_pointwiseApply1(THCState* state,
TensorTypeA* a,
const Op& op,
TensorArgType aType = ReadWrite) {
static_assert(std::is_same<ScalarTypeA, typename TensorUtils<TensorTypeA>::DataType>::value, "ScalarTypeA must match");
if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS) {
return false;
}
Expand Down Expand Up @@ -315,7 +317,9 @@ bool THC_pointwiseApply1(THCState* state,
return true;
}

template <typename TensorTypeA,
template <typename ScalarTypeA,
typename ScalarTypeB,
typename TensorTypeA,
typename TensorTypeB,
typename Op>
bool THC_pointwiseApply2(THCState* state,
Expand All @@ -324,6 +328,9 @@ bool THC_pointwiseApply2(THCState* state,
const Op& op,
TensorArgType aType = ReadWrite,
TensorArgType bType = ReadOnly) {
static_assert(std::is_same<ScalarTypeA, typename TensorUtils<TensorTypeA>::DataType>::value, "ScalarTypeA must match");
static_assert(std::is_same<ScalarTypeB, typename TensorUtils<TensorTypeB>::DataType>::value, "ScalarTypeB must match");

ptrdiff_t totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
if (totalElements != TensorUtils<TensorTypeB>::getNumElements(state, b)) {
return false;
Expand Down Expand Up @@ -499,7 +506,10 @@ bool THC_pointwiseApply2(THCState* state,
return true;
}

template <typename TensorTypeA,
template <typename ScalarTypeA,
typename ScalarTypeB,
typename ScalarTypeC,
typename TensorTypeA,
typename TensorTypeB,
typename TensorTypeC,
typename Op>
Expand All @@ -511,6 +521,10 @@ bool THC_pointwiseApply3(THCState* state,
TensorArgType aType = ReadWrite,
TensorArgType bType = ReadOnly,
TensorArgType cType = ReadOnly) {
static_assert(std::is_same<ScalarTypeA, typename TensorUtils<TensorTypeA>::DataType>::value, "ScalarTypeA must match");
static_assert(std::is_same<ScalarTypeB, typename TensorUtils<TensorTypeB>::DataType>::value, "ScalarTypeB must match");
static_assert(std::is_same<ScalarTypeC, typename TensorUtils<TensorTypeC>::DataType>::value, "ScalarTypeC must match");

ptrdiff_t totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);

if (totalElements != TensorUtils<TensorTypeB>::getNumElements(state, b) ||
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THC/THCTensorCopy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ THC_copyTensor(THCState* state, TensorTypeDst* dst, TensorTypeSrc* src) {
// might be worth it to avoid non-coalesced reads or writes.
if (p2pEnabled) {
bool succ =
THC_pointwiseApply2(
THC_pointwiseApply2<typename TensorUtils<TensorTypeDst>::DataType,
typename TensorUtils<TensorTypeSrc>::DataType>(
state, dst, src,
CopyOp<typename TensorUtils<TensorTypeDst>::DataType,
typename TensorUtils<TensorTypeSrc>::DataType>());
Expand All @@ -139,7 +140,8 @@ THC_copyTensor(THCState* state, TensorTypeDst* dst, TensorTypeSrc* src) {
TensorUtils<TensorTypeDst>::resizeAs(state, srcContig, dst);

bool succ =
THC_pointwiseApply2(
THC_pointwiseApply2<typename TensorUtils<TensorTypeDst>::DataType,
typename TensorUtils<TensorTypeSrc>::DataType>(
state, srcContig, src,
CopyOp<typename TensorUtils<TensorTypeDst>::DataType,
typename TensorUtils<TensorTypeSrc>::DataType>());
Expand Down
4 changes: 2 additions & 2 deletions aten/src/THC/THCTensorIndex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,10 @@ void dispatchTakePutImpl(THCState *state, TensorType *a, TensorType *b, THCudaLo
auto numel = TensorUtils<TensorType>::getNumElements(state, a);
if (aInfo.isContiguous()) {
auto op = Op<real, IndexType, -2>(aInfo, numel, start, end);
THC_pointwiseApply2(state, b, index, op);
THC_pointwiseApply2<real, int64_t>(state, b, index, op);
} else {
auto op = Op<real, IndexType, -1>(aInfo, numel, start, end);
THC_pointwiseApply2(state, b, index, op);
THC_pointwiseApply2<real, int64_t>(state, b, index, op);
}
}

Expand Down
4 changes: 2 additions & 2 deletions aten/src/THC/THCTensorMathCompare.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct TensorNEValueOp {
const T value;
};

template<typename TensorType, typename TensorTypeOut, class Op>
template<typename ScalarTypeOut, typename ScalarType, typename TensorTypeOut, typename TensorType, class Op>
void THC_logicalValue(THCState *state,
TensorTypeOut *self_,
TensorType *src,
Expand All @@ -77,7 +77,7 @@ void THC_logicalValue(THCState *state,
TensorUtils<TensorTypeOut>::resize(state, self_, st, NULL);
THLongStorage_free(st);

if (!THC_pointwiseApply2(state, self_, src, op)) {
if (!THC_pointwiseApply2<ScalarTypeOut, ScalarType>(state, self_, src, op)) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}

Expand Down
4 changes: 2 additions & 2 deletions aten/src/THC/THCTensorMathCompareT.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct TensorNEOp {
}
};

template<typename TensorType, typename TensorTypeOut, typename Op>
template<typename ScalarTypeOut, typename ScalarType, typename TensorTypeOut, typename TensorType, typename Op>
void THC_logicalTensor(THCState *state,
TensorTypeOut *self_,
TensorType *src1,
Expand All @@ -64,7 +64,7 @@ void THC_logicalTensor(THCState *state,
TensorUtils<TensorType>::getNumElements(state, src2), 3,
"sizes do not match");

if (!THC_pointwiseApply3(state, self_, src1, src2, op)) {
if (!THC_pointwiseApply3<ScalarTypeOut, ScalarType, ScalarType>(state, self_, src1, src2, op)) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}

Expand Down
7 changes: 4 additions & 3 deletions aten/src/THC/generic/THCTensorCopy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ THCTensor_(copyIgnoringOverlaps)(THCState* state, THCTensor* dst, THCTensor* src
// This is itself invoked by pointwiseApply2 / THCTensor_copy in
// case that there are write overlaps.
// FIXME: really, overlapping writes should be illegal/an error in Torch
THC_pointwiseApply2(
THC_pointwiseApply2<real,
real>(
state, dst, src,
CopyOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(),
CopyOp<real,
real>(),
ReadOnly, /* ignore overwrites */
ReadOnly);
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/generic/THCTensorIndex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void THCTensor_(put)(THCState *state, THCTensor *dst, THCudaLongTensor *index, T
// wrap indices so as to replace negative indices
THCudaLongTensor* sorted_index = THCudaLongTensor_new(state);
THCudaLongTensor_resizeAs(state, sorted_index, index);
THC_pointwiseApply2(state, sorted_index, index, WrapIndexOp(dstSize));
THC_pointwiseApply2<int64_t, int64_t>(state, sorted_index, index, WrapIndexOp(dstSize));

THCTensor* sorted_src = THCTensor_(newClone)(state, src);

Expand Down
8 changes: 4 additions & 4 deletions aten/src/THC/generic/THCTensorMasked.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ THCTensor_(maskedFill)(THCState* state,
THCudaByteTensor_nElement(state, mask),
2, "sizes do not match");

if (!THC_pointwiseApply2(state, tensor, mask,
TensorMaskedFillOp<real, unsigned char>(value))) {
if (!THC_pointwiseApply2<real, uint8_t>(state, tensor, mask,
TensorMaskedFillOp<real, unsigned char>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}

Expand Down Expand Up @@ -88,7 +88,7 @@ THCTensor_(maskedCopy)(THCState* state,

// update `tensor` where `mask` == 1 but pull from `src` at
// maskPrefixSum
bool status = THC_pointwiseApply3(
bool status = THC_pointwiseApply3<real, uint8_t, int64_t>(
state, tensor, mask, maskPrefixSum,
TensorMaskedCopyOp<real, unsigned char, int64_t>(
THCTensor_(data)(state, contigSrc)));
Expand Down Expand Up @@ -158,7 +158,7 @@ THCTensor_(maskedSelect)(THCState* state,
maskPrefixSumData);

// Then copy over the masked elements at their desired output index
bool status = THC_pointwiseApply3(
bool status = THC_pointwiseApply3<uint8_t, int64_t, real>(
state, mask, maskPrefixSum,
src, TensorMaskedSelectOp<real, unsigned char, int64_t>(
THCTensor_(data)(state, tensor)));
Expand Down
4 changes: 2 additions & 2 deletions aten/src/THC/generic/THCTensorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ THCTensor_(fill)(THCState* state, THCTensor *self_, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));

if (!THC_pointwiseApply1(
if (!THC_pointwiseApply1<real>(
state, self_, TensorFillOp<real>(value))) {
THArgCheck(false, 1, CUTORCH_DIM_WARNING);
}
Expand All @@ -25,7 +25,7 @@ THCTensor_(zero)(THCState *state, THCTensor *self_)
sizeof(real) * THCTensor_(nElement)(state, self_),
THCState_getCurrentStream(state)));
} else {
if (!THC_pointwiseApply1(
if (!THC_pointwiseApply1<real>(
state, self_,
TensorFillOp<real>(ScalarConvert<int, real>::to(0)))) {
THArgCheck(false, 1, CUTORCH_DIM_WARNING);
Expand Down
72 changes: 36 additions & 36 deletions aten/src/THC/generic/THCTensorMathCompare.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,97 +5,97 @@
THC_API void THCTensor_(ltValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorLTValueOp<typename TensorUtils<THCTensor>::DataType,
unsigned char>(value));
THC_logicalValue<uint8_t, real>(state, self_, src,
TensorLTValueOp<real,
unsigned char>(value));
}

THC_API void THCTensor_(gtValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorGTValueOp<typename TensorUtils<THCTensor>::DataType,
unsigned char>(value));
THC_logicalValue<uint8_t, real>(state, self_, src,
TensorGTValueOp<real,
unsigned char>(value));
}

THC_API void THCTensor_(leValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorLEValueOp<typename TensorUtils<THCTensor>::DataType,
unsigned char>(value));
THC_logicalValue<uint8_t, real>(state, self_, src,
TensorLEValueOp<real,
unsigned char>(value));
}

THC_API void THCTensor_(geValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorGEValueOp<typename TensorUtils<THCTensor>::DataType,
unsigned char>(value));
THC_logicalValue<uint8_t, real>(state, self_, src,
TensorGEValueOp<real,
unsigned char>(value));
}

THC_API void THCTensor_(eqValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorEQValueOp<typename TensorUtils<THCTensor>::DataType,
unsigned char>(value));
THC_logicalValue<uint8_t, real>(state, self_, src,
TensorEQValueOp<real,
unsigned char>(value));
}

THC_API void THCTensor_(neValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorNEValueOp<typename TensorUtils<THCTensor>::DataType,
unsigned char>(value));
THC_logicalValue<uint8_t, real>(state, self_, src,
TensorNEValueOp<real,
unsigned char>(value));
}

THC_API void THCTensor_(ltValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorLTValueOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(value));
THC_logicalValue<real, real>(state, self_, src,
TensorLTValueOp<real,
real>(value));
}

THC_API void THCTensor_(gtValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorGTValueOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(value));
THC_logicalValue<real, real>(state, self_, src,
TensorGTValueOp<real,
real>(value));
}

THC_API void THCTensor_(leValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorLEValueOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(value));
THC_logicalValue<real, real>(state, self_, src,
TensorLEValueOp<real,
real>(value));
}

THC_API void THCTensor_(geValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorGEValueOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(value));
THC_logicalValue<real, real>(state, self_, src,
TensorGEValueOp<real,
real>(value));
}

THC_API void THCTensor_(eqValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorEQValueOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(value));
THC_logicalValue<real, real>(state, self_, src,
TensorEQValueOp<real,
real>(value));
}

THC_API void THCTensor_(neValueT)(THCState *state, THCTensor *self_, THCTensor *src, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
THC_logicalValue(state, self_, src,
TensorNEValueOp<typename TensorUtils<THCTensor>::DataType,
typename TensorUtils<THCTensor>::DataType>(value));
THC_logicalValue<real, real>(state, self_, src,
TensorNEValueOp<real,
real>(value));
}

#endif