Replace most remaining usages of TensorUtils<T>::DataType. (pytorch#8124)

As in pytorch#8056, TensorUtils<T>::DataType doesn't work with a single TensorImpl type.
This change replaces those usages with a templatized ScalarType parameter and adds static_asserts that the new and old types are equal.

After this we can get rid of the old template parameter, but I want to ensure the two are equivalent across all builds first.
gchanan authored and weiyangfb committed Jun 11, 2018
1 parent c65b6c5 commit 44e94b7
Showing 19 changed files with 334 additions and 324 deletions.
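
Before the per-file diffs, here is a minimal self-contained sketch of the pattern the commit applies throughout. The tensor type, the TensorUtils specialization, and scaleOld/scaleNew below are illustrative stand-ins, not code from this diff; only the shape of the change mirrors the hunks that follow.

#include <type_traits>

// Illustrative stand-ins so the sketch compiles on its own; in THC the real
// tensor types and the TensorUtils traits being phased out play these roles.
struct ExampleTensor { float* data; long size; };
template <typename T> struct TensorUtils;
template <> struct TensorUtils<ExampleTensor> { using DataType = float; };

// Before: the element type is derived inside the template.
template <typename TensorType>
void scaleOld(TensorType* t, typename TensorUtils<TensorType>::DataType factor) {
  for (long i = 0; i < t->size; ++i) t->data[i] *= factor;
}

// After: the element type is an explicit ScalarType parameter, and a
// static_assert checks it against the old derivation until the TensorType
// parameter can be dropped entirely.
template <typename ScalarType, typename TensorType>
void scaleNew(TensorType* t, ScalarType factor) {
  static_assert(
      std::is_same<ScalarType, typename TensorUtils<TensorType>::DataType>::value,
      "ScalarType must match TensorUtils<TensorType>::DataType");
  for (long i = 0; i < t->size; ++i) t->data[i] *= factor;
}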
120 changes: 60 additions & 60 deletions aten/src/THC/THCApply.cuh
@@ -234,10 +234,10 @@ bool THC_pointwiseApply1(THCState* state,
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, A) \
kernelPointwiseApply1<Op, \
- typename TensorUtils<TensorTypeA>::DataType, \
+ ScalarTypeA, \
TYPE, A> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, TYPE, A> \
+ OffsetInfo<ScalarTypeA, TYPE, A> \
(aInfo), \
(TYPE) totalElements, op);

@@ -260,8 +260,8 @@ bool THC_pointwiseApply1(THCState* state,
// We also use unsigned index math in the kernel, as signed div/mod has
// additional overhead.
if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a)) {
- TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
- getTensorInfo<TensorTypeA, unsigned int>(state, a);
+ TensorInfo<ScalarTypeA, unsigned int> aInfo =
+ getTensorInfo<ScalarTypeA, TensorTypeA, unsigned int>(state, a);
rearrangeDims(&aInfo);
aInfo.collapseDims();
#if CUDA_VERSION < 9000
@@ -271,8 +271,8 @@ bool THC_pointwiseApply1(THCState* state,
#endif
HANDLE_A_CASE(unsigned int, aInfo.dims);
} else {
- TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
- getTensorInfo<TensorTypeA, uint64_t>(state, a);
+ TensorInfo<ScalarTypeA, uint64_t> aInfo =
+ getTensorInfo<ScalarTypeA, TensorTypeA, uint64_t>(state, a);
rearrangeDims(&aInfo);
aInfo.collapseDims();

@@ -281,10 +281,10 @@ bool THC_pointwiseApply1(THCState* state,
large (64-bit indexed) tensors to reduce compilation time.
*/
if (aInfo.dims == 1) {
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t, 1>
+ OffsetInfo<ScalarTypeA, uint64_t, 1>
aOffset(aInfo);
kernelPointwiseApply1<Op,
- typename TensorUtils<TensorTypeA>::DataType,
+ ScalarTypeA,
uint64_t, 1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aOffset, (uint64_t) totalElements, op);
@@ -293,10 +293,10 @@ bool THC_pointwiseApply1(THCState* state,
#if CUDA_VERSION < 9000
grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t, -1>
+ OffsetInfo<ScalarTypeA, uint64_t, -1>
aOffset(aInfo);
kernelPointwiseApply1<Op,
- typename TensorUtils<TensorTypeA>::DataType,
+ ScalarTypeA,
uint64_t, -1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aOffset, (uint64_t) totalElements, op);
@@ -384,13 +384,13 @@ bool THC_pointwiseApply2(THCState* state,
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, A, B) \
kernelPointwiseApply2<Op, \
- typename TensorUtils<TensorTypeA>::DataType, \
- typename TensorUtils<TensorTypeB>::DataType, \
+ ScalarTypeA, \
+ ScalarTypeB, \
TYPE, A, B> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, TYPE, A> \
+ OffsetInfo<ScalarTypeA, TYPE, A> \
(aInfo), \
- OffsetInfo<typename TensorUtils<TensorTypeB>::DataType, TYPE, B> \
+ OffsetInfo<ScalarTypeB, TYPE, B> \
(bInfo), \
(TYPE) totalElements, op);

@@ -424,11 +424,11 @@ bool THC_pointwiseApply2(THCState* state,

if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a) &&
TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b)) {
- TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
- getTensorInfo<TensorTypeA, unsigned int>(state, a);
+ TensorInfo<ScalarTypeA, unsigned int> aInfo =
+ getTensorInfo<ScalarTypeA, TensorTypeA, unsigned int>(state, a);

- TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
- getTensorInfo<TensorTypeB, unsigned int>(state, b);
+ TensorInfo<ScalarTypeB, unsigned int> bInfo =
+ getTensorInfo<ScalarTypeB, TensorTypeB, unsigned int>(state, b);

rearrangeDims(&aInfo, &bInfo);
aInfo.collapseDims();
@@ -440,11 +440,11 @@ bool THC_pointwiseApply2(THCState* state,

HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
} else {
- TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
- getTensorInfo<TensorTypeA, uint64_t>(state, a);
+ TensorInfo<ScalarTypeA, uint64_t> aInfo =
+ getTensorInfo<ScalarTypeA, TensorTypeA, uint64_t>(state, a);

- TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
- getTensorInfo<TensorTypeB, uint64_t>(state, b);
+ TensorInfo<ScalarTypeB, uint64_t> bInfo =
+ getTensorInfo<ScalarTypeB, TensorTypeB, uint64_t>(state, b);

rearrangeDims(&aInfo, &bInfo);
aInfo.collapseDims();
@@ -455,27 +455,27 @@ bool THC_pointwiseApply2(THCState* state,
large (64-bit indexed) tensors to reduce compilation time.
*/
if (aInfo.dims == 1 && bInfo.dims == 1) {
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t, 1>
+ OffsetInfo<ScalarTypeA, uint64_t, 1>
aOffset(aInfo);
- OffsetInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t, 1>
+ OffsetInfo<ScalarTypeB, uint64_t, 1>
bOffset(bInfo);
kernelPointwiseApply2<Op,
- typename TensorUtils<TensorTypeA>::DataType,
- typename TensorUtils<TensorTypeB>::DataType,
+ ScalarTypeA,
+ ScalarTypeB,
uint64_t, 1, 1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aOffset, bOffset, (uint64_t) totalElements, op);
} else {
#if CUDA_VERSION < 9000
grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t, -1>
+ OffsetInfo<ScalarTypeA, uint64_t, -1>
aOffset(aInfo);
- OffsetInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t, -1>
+ OffsetInfo<ScalarTypeB, uint64_t, -1>
bOffset(bInfo);
kernelPointwiseApply2<Op,
- typename TensorUtils<TensorTypeA>::DataType,
- typename TensorUtils<TensorTypeB>::DataType,
+ ScalarTypeA,
+ ScalarTypeB,
uint64_t, -1, -1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aOffset, bOffset, (uint64_t) totalElements, op);
@@ -580,16 +580,16 @@ bool THC_pointwiseApply3(THCState* state,

#define HANDLE_CASE(TYPE, A, B, C) \
kernelPointwiseApply3<Op, \
- typename TensorUtils<TensorTypeA>::DataType, \
- typename TensorUtils<TensorTypeB>::DataType, \
- typename TensorUtils<TensorTypeC>::DataType, \
+ ScalarTypeA, \
+ ScalarTypeB, \
+ ScalarTypeC, \
TYPE, A, B, C> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, TYPE, A> \
+ OffsetInfo<ScalarTypeA, TYPE, A> \
(aInfo), \
- OffsetInfo<typename TensorUtils<TensorTypeB>::DataType, TYPE, B> \
+ OffsetInfo<ScalarTypeB, TYPE, B> \
(bInfo), \
- OffsetInfo<typename TensorUtils<TensorTypeC>::DataType, TYPE, C> \
+ OffsetInfo<ScalarTypeC, TYPE, C> \
(cInfo), \
(TYPE) totalElements, op);

@@ -638,14 +638,14 @@ bool THC_pointwiseApply3(THCState* state,
if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a) &&
TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b) &&
TensorUtils<TensorTypeC>::canUse32BitIndexMath(state, c)) {
- TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
- getTensorInfo<TensorTypeA, unsigned int>(state, a);
+ TensorInfo<ScalarTypeA, unsigned int> aInfo =
+ getTensorInfo<ScalarTypeA, TensorTypeA, unsigned int>(state, a);

- TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
- getTensorInfo<TensorTypeB, unsigned int>(state, b);
+ TensorInfo<ScalarTypeB, unsigned int> bInfo =
+ getTensorInfo<ScalarTypeB, TensorTypeB, unsigned int>(state, b);

- TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned int> cInfo =
- getTensorInfo<TensorTypeC, unsigned int>(state, c);
+ TensorInfo<ScalarTypeC, unsigned int> cInfo =
+ getTensorInfo<ScalarTypeC, TensorTypeC, unsigned int>(state, c);

rearrangeDims(&aInfo, &bInfo, &cInfo);
aInfo.collapseDims();
@@ -658,14 +658,14 @@ bool THC_pointwiseApply3(THCState* state,
#endif
HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims);
} else {
- TensorInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t> aInfo =
- getTensorInfo<TensorTypeA, uint64_t>(state, a);
+ TensorInfo<ScalarTypeA, uint64_t> aInfo =
+ getTensorInfo<ScalarTypeA, TensorTypeA, uint64_t>(state, a);

- TensorInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t> bInfo =
- getTensorInfo<TensorTypeB, uint64_t>(state, b);
+ TensorInfo<ScalarTypeB, uint64_t> bInfo =
+ getTensorInfo<ScalarTypeB, TensorTypeB, uint64_t>(state, b);

- TensorInfo<typename TensorUtils<TensorTypeC>::DataType, uint64_t> cInfo =
- getTensorInfo<TensorTypeC, uint64_t>(state, c);
+ TensorInfo<ScalarTypeC, uint64_t> cInfo =
+ getTensorInfo<ScalarTypeC, TensorTypeC, uint64_t>(state, c);

rearrangeDims(&aInfo, &bInfo, &cInfo);
aInfo.collapseDims();
@@ -677,16 +677,16 @@ bool THC_pointwiseApply3(THCState* state,
large (64-bit indexed) tensors to reduce compilation time.
*/
if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1) {
- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t, 1>
+ OffsetInfo<ScalarTypeA, uint64_t, 1>
aOffset(aInfo);
- OffsetInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t, 1>
+ OffsetInfo<ScalarTypeB, uint64_t, 1>
bOffset(bInfo);
- OffsetInfo<typename TensorUtils<TensorTypeC>::DataType, uint64_t, 1>
+ OffsetInfo<ScalarTypeC, uint64_t, 1>
cOffset(cInfo);
kernelPointwiseApply3<Op,
- typename TensorUtils<TensorTypeA>::DataType,
- typename TensorUtils<TensorTypeB>::DataType,
- typename TensorUtils<TensorTypeC>::DataType,
+ ScalarTypeA,
+ ScalarTypeB,
+ ScalarTypeC,
uint64_t, 1, 1, 1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aOffset, bOffset, cOffset, (uint64_t) totalElements, op);
@@ -695,16 +695,16 @@ bool THC_pointwiseApply3(THCState* state,
grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif

- OffsetInfo<typename TensorUtils<TensorTypeA>::DataType, uint64_t, -1>
+ OffsetInfo<ScalarTypeA, uint64_t, -1>
aOffset(aInfo);
- OffsetInfo<typename TensorUtils<TensorTypeB>::DataType, uint64_t, -1>
+ OffsetInfo<ScalarTypeB, uint64_t, -1>
bOffset(bInfo);
- OffsetInfo<typename TensorUtils<TensorTypeC>::DataType, uint64_t, -1>
+ OffsetInfo<ScalarTypeC, uint64_t, -1>
cOffset(cInfo);
kernelPointwiseApply3<Op,
- typename TensorUtils<TensorTypeA>::DataType,
- typename TensorUtils<TensorTypeB>::DataType,
- typename TensorUtils<TensorTypeC>::DataType,
+ ScalarTypeA,
+ ScalarTypeB,
+ ScalarTypeC,
uint64_t, -1, -1, -1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aOffset, bOffset, cOffset, (uint64_t) totalElements, op);
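
The THCApply.cuh hunks above change every getTensorInfo call to name the scalar type explicitly ahead of the tensor type. The declaration itself is not part of this diff, so the following is only a sketch under the assumption that the new overload takes ScalarType, TensorType, and IndexType in that order and cross-checks the first two during the transition; the toy types are hypothetical, and the real getTensorInfo also takes a THCState*.

#include <type_traits>

// Toy stand-ins so the sketch is self-contained; the real TensorInfo,
// TensorUtils, and getTensorInfo live in THC headers outside this diff.
struct ToyTensor { float* data; int dims; };
template <typename T> struct TensorUtils;
template <> struct TensorUtils<ToyTensor> { using DataType = float; };

template <typename ScalarType, typename IndexType>
struct TensorInfo {
  ScalarType* data;
  IndexType dims;
};

// Assumed shape of the new overload: the caller spells out the scalar type,
// and the tensor type is kept only so the old derivation can be verified.
template <typename ScalarType, typename TensorType, typename IndexType>
TensorInfo<ScalarType, IndexType> getTensorInfo(TensorType* t) {
  static_assert(
      std::is_same<ScalarType, typename TensorUtils<TensorType>::DataType>::value,
      "ScalarType must match TensorUtils<TensorType>::DataType");
  return TensorInfo<ScalarType, IndexType>{t->data, static_cast<IndexType>(t->dims)};
}

// Mirrors the call sites in the hunks above:
//   TensorInfo<float, unsigned int> aInfo =
//       getTensorInfo<float, ToyTensor, unsigned int>(&a);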
26 changes: 14 additions & 12 deletions aten/src/THC/THCReduce.cuh
@@ -267,7 +267,8 @@ inline bool getContigReduceGrid(ptrdiff_t elements, dim3& grid) {

// Performs a reduction out[..., 0, ...] = reduce_i(modify(in[..., i, ...])) for
// all in where i and the out's 0 are indexed at dimension `dim`
- template <typename TensorType,
+ template <typename ScalarType,
+ typename TensorType,
typename ModifyOp,
typename ReduceOp,
typename FinalizeOp,
@@ -281,6 +282,7 @@ bool THC_reduceDim(THCState* state,
AccT init,
int dim,
int keepdim) {
+ static_assert(std::is_same<ScalarType, typename TensorUtils<TensorType>::DataType>::value, "ScalarType must match");
ptrdiff_t inElements = TensorUtils<TensorType>::getNumElements(state, in);

int64_t reductionSize = TensorUtils<TensorType>::getSize(state, in, dim);
@@ -360,7 +362,7 @@ bool THC_reduceDim(THCState* state,
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, OUT, IN) \
if (contigReduction) { \
- kernelReduceContigDim<typename TensorUtils<TensorType>::DataType, \
+ kernelReduceContigDim<ScalarType, \
TYPE, AccT, ModifyOp, ReduceOp, FinalizeOp, \
OUT, IN> \
<<<grid, block, smemSize, THCState_getCurrentStream(state)>>> \
@@ -369,15 +371,15 @@ bool THC_reduceDim(THCState* state,
} else { \
if(block.y == 1){ \
kernelReduceNoncontigDim< \
- typename TensorUtils<TensorType>::DataType, \
+ ScalarType, \
TYPE, AccT, ModifyOp, ReduceOp, FinalizeOp, \
OUT, IN> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>> \
(outInfo, inInfo, reductionStride, reductionSize, \
(TYPE) outElements, init, modifyOp, reduceOp, finalizeOp); \
}else{ \
kernelReduceNoncontigDim_shared< \
- typename TensorUtils<TensorType>::DataType, \
+ ScalarType, \
TYPE, AccT, ModifyOp, ReduceOp, FinalizeOp, \
OUT, IN> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>> \
@@ -418,26 +420,26 @@ bool THC_reduceDim(THCState* state,

if (TensorUtils<TensorType>::canUse32BitIndexMath(state, out) &&
TensorUtils<TensorType>::canUse32BitIndexMath(state, in)) {
- TensorInfo<typename TensorUtils<TensorType>::DataType,
+ TensorInfo<ScalarType,
unsigned int> outInfo =
- getTensorInfo<TensorType, unsigned int>(state, out);
+ getTensorInfo<ScalarType, TensorType, unsigned int>(state, out);
outInfo.collapseDims();

- TensorInfo<typename TensorUtils<TensorType>::DataType,
+ TensorInfo<ScalarType,
unsigned int> inInfo =
- getTensorInfo<TensorType, unsigned int>(state, in);
+ getTensorInfo<ScalarType, TensorType, unsigned int>(state, in);
inInfo.reduceDim(dim);
inInfo.collapseDims();
HANDLE_OUT_CASE(unsigned int, outInfo.dims, inInfo.dims);
} else {
- TensorInfo<typename TensorUtils<TensorType>::DataType,
+ TensorInfo<ScalarType,
uint64_t> outInfo =
- getTensorInfo<TensorType, uint64_t>(state, out);
+ getTensorInfo<ScalarType, TensorType, uint64_t>(state, out);
outInfo.collapseDims();

- TensorInfo<typename TensorUtils<TensorType>::DataType,
+ TensorInfo<ScalarType,
uint64_t> inInfo =
- getTensorInfo<TensorType, uint64_t>(state, in);
+ getTensorInfo<ScalarType, TensorType, uint64_t>(state, in);
inInfo.reduceDim(dim);
inInfo.collapseDims();

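
The static_assert added to THC_reduceDim above is what the commit message means by ensuring the new and old types are equal across all builds: a call site that names the wrong scalar type fails to compile rather than silently reinterpreting memory. A small self-contained illustration of that guard, with hypothetical types that are not part of THC:

#include <type_traits>

// Hypothetical stand-ins, not THC declarations.
struct ShortTensor { short* data; };
template <typename T> struct TensorUtils;
template <> struct TensorUtils<ShortTensor> { using DataType = short; };

// The guard pattern used by THC_reduceDim/THC_reduceAll in this commit.
template <typename ScalarType, typename TensorType>
void checkedLaunch(TensorType* /*t*/) {
  static_assert(
      std::is_same<ScalarType, typename TensorUtils<TensorType>::DataType>::value,
      "ScalarType must match TensorUtils<TensorType>::DataType");
  // ... kernel launch would go here ...
}

void ok(ShortTensor* t) { checkedLaunch<short, ShortTensor>(t); }    // compiles
// void bad(ShortTensor* t) { checkedLaunch<float, ShortTensor>(t); }  // static_assert fires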
14 changes: 8 additions & 6 deletions aten/src/THC/THCReduceAll.cuh
@@ -219,7 +219,8 @@ void callReduceAll(THCState* state,

// Reduces the entire tensor to one value. `out` points to
// host-resident memory.
- template <typename TensorType,
+ template <typename ScalarType,
+ typename TensorType,
typename ModifyOp,
typename ReduceOp,
typename AccT>
@@ -230,6 +231,7 @@ bool THC_reduceAll(THCState* state,
AccT init,
AccT* out,
int outOnDevice) {
+ static_assert(std::is_same<ScalarType, typename TensorUtils<TensorType>::DataType>::value, "ScalarTypeA must match");
ptrdiff_t inElements = TensorUtils<TensorType>::getNumElements(state, in);

if (TensorUtils<TensorType>::getDims(state, in) > MAX_CUTORCH_DIMS) {
@@ -261,7 +263,7 @@ bool THC_reduceAll(THCState* state,
// dimension, and the loop to translate the linear index to the array
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, IN) \
- callReduceAll<typename TensorUtils<TensorType>::DataType, \
+ callReduceAll<ScalarType, \
TYPE, AccT, ModifyOp, ReduceOp, IN>( \
state, inInfo, inElements, init, modifyOp, \
reduceOp, devOut);
@@ -282,15 +284,15 @@ bool THC_reduceAll(THCState* state,
}

if (TensorUtils<TensorType>::canUse32BitIndexMath(state, in)) {
- TensorInfo<typename TensorUtils<TensorType>::DataType, unsigned int> inInfo =
- getTensorInfo<TensorType, unsigned int>(state, in);
+ TensorInfo<ScalarType, unsigned int> inInfo =
+ getTensorInfo<ScalarType, TensorType, unsigned int>(state, in);
inInfo.collapseDims();

HANDLE_IN_CASE(unsigned int, inInfo.dims);
} else {
- TensorInfo<typename TensorUtils<TensorType>::DataType,
+ TensorInfo<ScalarType,
uint64_t> inInfo =
- getTensorInfo<TensorType, uint64_t>(state, in);
+ getTensorInfo<ScalarType, TensorType, uint64_t>(state, in);
inInfo.collapseDims();

/*
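
With the scalar type leading the template parameter list, callers of the reduce entry points spell out only ScalarType (and, for now, TensorType) while the functor and accumulator types can still be deduced from the arguments. The toy reduction below sketches that calling convention under hypothetical names; it is not the THC_reduceAll implementation.

#include <cstddef>
#include <type_traits>

// Hypothetical types for a self-contained sketch.
struct DoubleTensor { double* data; std::ptrdiff_t n; };
template <typename T> struct TensorUtils;
template <> struct TensorUtils<DoubleTensor> { using DataType = double; };

// Mirrors the new parameter order: ScalarType first, then the TensorType that
// is slated for removal, then deduced functor/accumulator types.
template <typename ScalarType,
          typename TensorType,
          typename ModifyOp,
          typename ReduceOp,
          typename AccT>
AccT reduceAllToy(TensorType* in, ModifyOp modify, ReduceOp reduce, AccT init) {
  static_assert(
      std::is_same<ScalarType, typename TensorUtils<TensorType>::DataType>::value,
      "ScalarType must match TensorUtils<TensorType>::DataType");
  AccT acc = init;
  for (std::ptrdiff_t i = 0; i < in->n; ++i) {
    acc = reduce(acc, modify(static_cast<ScalarType>(in->data[i])));
  }
  return acc;
}

// Only the first two template arguments are written out; the rest are deduced:
//   double sum = reduceAllToy<double, DoubleTensor>(
//       &t, [](double x) { return x; }, [](double a, double b) { return a + b; }, 0.0);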
