Update TH, THC, THNN, THCUNN #452

Merged Jan 14, 2017 (23 commits)
Commits
5340291 Update FindARM.cmake (temerick, Jan 3, 2017)
71cef62 Fix condition for threadArgErrorHandler (colesbury, Jan 5, 2017)
d186fdb Fix THHalf issues with MSVC. (gchanan, Jan 5, 2017)
35e1adf documentation parity with torch7 for catArray impl (killeent, Jan 6, 2017)
d070178 Instantiate 128kb of scratch space in GPU memory per-device by default (killeent, Jan 9, 2017)
35758f5 Get rid of a few unused imports. (gchanan, Jan 9, 2017)
17c998e fixing arm64 build (soumith, Jan 10, 2017)
68e2769 Re-route thrust memory allocation to THCudaMalloc / THCudaFree (gchanan, Jan 10, 2017)
4a8906d Add THCThrustAllocator.cuh to install files to downstream projects ca… (gchanan, Jan 10, 2017)
5065197 Merge pull request #666 from gchanan/thrustalloc (soumith, Jan 10, 2017)
2b88d85 Re-route thrust memory allocation to THCudaMalloc / THCudaFree in cunn. (gchanan, Jan 10, 2017)
3e91c5e Merge pull request #668 from gchanan/thrustalloc (soumith, Jan 11, 2017)
b4bb4b6 simd.h: really fix the arm64 (i.e. Aarch64) build (cdluminate, Jan 11, 2017)
82088a8 parallelizing catArray to multiple tensors per kernel (#635) (killeent, Jan 12, 2017)
b076944 Fix for atomicAdd(double) for CUDA_VERSION < 8000 (pavanky, Jan 13, 2017)
f467848 Avoid strict aliasing warning in float/half conversions. (gchanan, Jan 13, 2017)
5171e56 Ensure atomicAdd(double) is visible to host side code (pavanky, Jan 13, 2017)
e67b525 Merge pull request #911 from gchanan/convWarning (soumith, Jan 13, 2017)
eab5c19 Avoid strict aliasing warning in float/half conversions. (gchanan, Jan 13, 2017)
ca74bb1 Merge pull request #675 from pavanky/more-atomic-fix (soumith, Jan 13, 2017)
b8a5b1e Merge commit 'e67b525388a5ae11ed243e94bbc25b4934b03a66' (soumith, Jan 13, 2017)
b5c9f5c Merge commit 'ca74bb17b8823d74b83433e2743f23e572501c72' (soumith, Jan 13, 2017)
fd600b1 Merge commit '2b88d85505d7317f980e69201e72694d6d5905a4' (soumith, Jan 13, 2017)
2 changes: 1 addition & 1 deletion torch/lib/TH/THGeneral.c
@@ -109,7 +109,7 @@ void _THArgCheck(const char *file, int line, int condition, int argNumber, const
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
}

- if (threadArgErrorHandlerData)
+ if (threadArgErrorHandler)
(*threadArgErrorHandler)(argNumber, msg, threadArgErrorHandlerData);
else
(*defaultArgErrorHandler)(argNumber, msg, defaultArgErrorHandlerData);
8 changes: 5 additions & 3 deletions torch/lib/TH/THHalf.c
@@ -77,15 +77,17 @@ float TH_half2float(THHalf h)
}

int temp = ((sign << 31) | (exponent << 23) | mantissa);

- return *((float*)((void*)&temp));
+ float x;
+ memcpy(&x,&temp,sizeof(float));
+ return x;
}

THHalf TH_float2half(float f)
{
THHalf ret;

- unsigned x = *((int*)(void*)(&f));
+ unsigned x;
+ memcpy(&x,&f,sizeof(f));
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;

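The removed pointer casts (e.g. *((float*)((void*)&temp))) read an int object through a float lvalue, which violates C's strict-aliasing rule and is what the compiler warning flagged; memcpy expresses the same bit reinterpretation with defined behavior and typically compiles to a single register move. A minimal standalone sketch of the pattern, not part of the patch, assuming a 32-bit unsigned int:

#include <stdio.h>
#include <string.h>

/* Type-pun a float to its 32-bit pattern without breaking strict
 * aliasing: memcpy is well-defined for this, and optimizing compilers
 * lower it to a plain move. */
static unsigned float_bits(float f) {
    unsigned u;
    memcpy(&u, &f, sizeof u);
    return u;
}

int main(void) {
    printf("0x%08x\n", float_bits(1.0f)); /* prints 0x3f800000 */
    return 0;
}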
4 changes: 2 additions & 2 deletions torch/lib/TH/cmake/FindARM.cmake
@@ -68,9 +68,9 @@ if(NOT NEON_FOUND)
MESSAGE(STATUS "Could not find hardware support for NEON on this machine.")
endif(NOT NEON_FOUND)
if(NOT CORTEXA8_FOUND)
MESSAGE(STATUS "No OMAP3 processor on this on this machine.")
MESSAGE(STATUS "No OMAP3 processor on this machine.")
endif(NOT CORTEXA8_FOUND)
if(NOT CORTEXA9_FOUND)
MESSAGE(STATUS "No OMAP4 processor on this on this machine.")
MESSAGE(STATUS "No OMAP4 processor on this machine.")
endif(NOT CORTEXA9_FOUND)
mark_as_advanced(NEON_FOUND)
10 changes: 8 additions & 2 deletions torch/lib/TH/generic/THTensorCopy.c
@@ -4,7 +4,7 @@

void THTensor_(copy)(THTensor *tensor, THTensor *src)
{
- TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = (real)(*src_data);)
+ TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;)
}

#define IMPLEMENT_THTensor_COPY(TYPENAMESRC, TYPE_SRC) \
@@ -25,6 +25,12 @@ void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src
TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)TH_half2float(*src_data);) \
}

+ #define IMPLEMENT_THTensor_COPY_TO_FROM_HALF(TYPENAMESRC, TYPE_SRC) \
+ void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
+ { \
+ TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = *src_data;) \
+ }

#ifndef TH_REAL_IS_HALF
IMPLEMENT_THTensor_COPY(Byte, unsigned char)
IMPLEMENT_THTensor_COPY(Char, char)
@@ -36,7 +42,7 @@ IMPLEMENT_THTensor_COPY(Double, double)
IMPLEMENT_THTensor_COPY_FROM_HALF(Half, THHalf)
#else
/* only allow pass-through for Half */
- IMPLEMENT_THTensor_COPY(Half, THHalf)
+ IMPLEMENT_THTensor_COPY_TO_FROM_HALF(Half, THHalf)
IMPLEMENT_THTensor_COPY_TO_HALF(Byte, unsigned char)
IMPLEMENT_THTensor_COPY_TO_HALF(Char, char)
IMPLEMENT_THTensor_COPY_TO_HALF(Short, short)
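A plausible reading of why the half-to-half case needs its own cast-free macro: when real is THHalf (a struct in TH's software-half fallback), the generic macro's (real)(*src_data) is a cast to a struct type, which C forbids because cast targets must be scalar. A sketch under that assumption; the THHalf layout here is assumed, not quoted from TH:

#include <stddef.h>

typedef struct { unsigned short x; } THHalf; /* assumed layout */

/* Pass-through copy for same-type half tensors: plain struct assignment
 * copies the raw bits, whereas a cast like (THHalf)(*src) would not
 * compile, since C permits casts only to scalar types. */
static void copyHalfArray(THHalf *dst, const THHalf *src, ptrdiff_t n) {
    for (ptrdiff_t i = 0; i < n; ++i)
        dst[i] = src[i];
}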
4 changes: 2 additions & 2 deletions torch/lib/TH/generic/simd/simd.h
@@ -53,7 +53,7 @@ enum SIMDExtensions
};


- #if defined(__arm__)
+ #if defined(__arm__) || defined(__aarch64__) // incl. armel, armhf, arm64

#if defined(__NEON__)

@@ -80,7 +80,7 @@ static inline uint32_t detectHostSIMDExtensions()
return SIMDExtension_VSX;
}

- #else
+ #else // PPC64 without VSX

static inline uint32_t detectHostSIMDExtensions()
{
1 change: 1 addition & 0 deletions torch/lib/THC/CMakeLists.txt
@@ -268,6 +268,7 @@ INSTALL(FILES
THCTensorTypeUtils.cuh
THCTensorRandom.cuh
THCTensorMathMagma.cuh
+ THCThrustAllocator.cuh
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC")

INSTALL(FILES
5 changes: 4 additions & 1 deletion torch/lib/THC/THCAtomics.cuh
@@ -110,7 +110,7 @@ static inline __device__ void atomicAdd(half *address, half val) {
}
#endif

- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// from the CUDA C Programming Guide
static inline __device__ void atomicAdd(double *address, double val) {
unsigned long long int* address_as_ull = (unsigned long long int*)address;
@@ -126,6 +126,9 @@ static inline __device__ void atomicAdd(double *address, double val) {
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
}
+ #elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000)
+ // This needs to be defined for the host-side compilation pass
+ static inline __device__ void atomicAdd(double *address, double val) { }
#endif

#endif // THC_ATOMICS_INC
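For context, the compare-and-swap loop this file provides on pre-sm_60 targets (and now also whenever CUDA_VERSION < 8000, since CUDA 8 is the toolkit that added the native double overload for sm_60) follows the well-known pattern from the CUDA C Programming Guide. A self-contained sketch of that pattern:

// Software atomicAdd for double: reinterpret the value as a 64-bit
// integer and retry with atomicCAS until no other thread raced us.
__device__ double atomicAddDouble(double* address, double val) {
  unsigned long long int* address_as_ull = (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
    // Comparing the integer bit patterns (rather than doubles) avoids an
    // infinite loop when the stored value is NaN, since NaN != NaN.
  } while (assumed != old);
  return __longlong_as_double(old);
}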
21 changes: 16 additions & 5 deletions torch/lib/THC/THCGeneral.c
@@ -1,7 +1,6 @@
#include "THCGeneral.h"
#include "TH.h"
#include "THCAllocator.h"
#include "THCBlas.h"
#include "THCCachingHostAllocator.h"
#include "THCStream.h"
#include "THCThreadLocal.h"
@@ -10,7 +9,12 @@
#include <stdint.h>

/* Size of scratch space available in global memory per each SM + stream */
- #define GLOBAL_SCRATCH_SPACE_PER_SM_STREAM 4 * sizeof(float)
+ #define MIN_GLOBAL_SCRATCH_SPACE_PER_SM_STREAM 4 * sizeof(float)
+
+ /* Minimum amount of scratch space per device. Total scratch memory per
+  * device is either this amount, or the # of SMs * the space per SM defined
+  * above, whichever is greater.*/
+ #define MIN_GLOBAL_SCRATCH_SPACE_PER_DEVICE 32768 * sizeof(float)

THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr(
THCState *state, int device);
@@ -108,9 +112,15 @@ void THCudaInit(THCState* state)
res->streams[0] = NULL;

/* The scratch space that we want to have available per each device is
- based on the number of SMs available per device */
+ based on the number of SMs available per device. We guarantee a
+ minimum of 128kb of space per device, but to future-proof against
+ future architectures that may have huge #s of SMs, we guarantee that
+ we have at least 16 bytes for each SM. */
int numSM = state->deviceProperties[i].multiProcessorCount;
- size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
+ size_t sizePerStream =
+   MIN_GLOBAL_SCRATCH_SPACE_PER_DEVICE >= numSM * MIN_GLOBAL_SCRATCH_SPACE_PER_SM_STREAM ?
+   MIN_GLOBAL_SCRATCH_SPACE_PER_DEVICE :
+   numSM * MIN_GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
res->scratchSpacePerStream = sizePerStream;
}

@@ -753,7 +763,8 @@ void THCHeapUpdate(THCState *state, ptrdiff_t size) {
}
}

- #undef GLOBAL_SCRATCH_SPACE_PER_SM_STREAM
+ #undef MIN_GLOBAL_SCRATCH_SPACE_PER_SM_STREAM
+ #undef MIN_GLOBAL_SCRATCH_SPACE_PER_DEVICE

#include "THCStorage.c"
#include "THCAllocator.c"
14 changes: 10 additions & 4 deletions torch/lib/THC/THCHalf.cu
@@ -1,4 +1,5 @@
#include "THCHalf.h"
#include "THCThrustAllocator.cuh"
#include <thrust/transform.h>
#include <thrust/execution_policy.h>

@@ -11,19 +12,21 @@ struct __float2halfOp {
};

void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len) {
+ THCThrustAllocator thrustAlloc(state);
thrust::transform(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#else
thrust::device,
#endif
in, in + len, out, __float2halfOp());
}

void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len) {
+ THCThrustAllocator thrustAlloc(state);
thrust::transform(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#else
thrust::device,
#endif
@@ -58,14 +61,17 @@ float THC_half2float(half h)

int temp = ((sign << 31) | (exponent << 23) | mantissa);

- return *((float*)((void*)&temp));
+ float x;
+ memcpy(&x,&temp,sizeof(float));
+ return x;
}

half THC_float2half(float f)
{
half ret;

- unsigned x = *((int*)(void*)(&f));
+ unsigned x;
+ memcpy(&x,&f,sizeof(f));
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;

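thrust::cuda::par(alloc) accepts any object exposing Thrust's minimal allocator protocol (a char value_type plus allocate/deallocate), which is presumably how THCThrustAllocator.cuh routes Thrust's temporary buffers through THCudaMalloc/THCudaFree instead of raw cudaMalloc. A sketch of that shape; the THC function signatures are assumed from headers of this era, not quoted from the patch:

#include "THCGeneral.h"   // THCState, THCudaMalloc, THCudaFree (assumed)
#include <cstddef>

// Hypothetical stand-in for THCThrustAllocator: Thrust requests raw
// bytes via allocate() and returns them via deallocate().
struct ThrustAllocatorSketch {
  typedef char value_type;

  explicit ThrustAllocatorSketch(THCState* state) : state_(state) {}

  char* allocate(std::ptrdiff_t size) {
    void* p = NULL;
    THCudaCheck(THCudaMalloc(state_, &p, size));  // device-side temp buffer
    return static_cast<char*>(p);
  }

  void deallocate(char* p, size_t size) {
    THCudaCheck(THCudaFree(state_, p));  // size unused; required by protocol
  }

  THCState* state_;
};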
5 changes: 5 additions & 0 deletions torch/lib/THC/THCReduceAll.cuh
@@ -137,6 +137,11 @@ inline ptrdiff_t getTwoPassBlocks(THCState* state, ptrdiff_t elements) {
THCState_getCurrentDeviceScratchSpaceSize(state) / sizeof(AccT);
THAssert(scratchSpace > 0);

+ // Limit to 1024 due to dimensionality constraint
+ if (scratchSpace > 1024) {
+   scratchSpace = 1024;
+ }

if (numBlocks > scratchSpace) {
numBlocks = scratchSpace;
}
1 change: 1 addition & 0 deletions torch/lib/THC/THCStorage.cu
@@ -1,5 +1,6 @@
#include "THCStorage.h"

#include "THCThrustAllocator.cuh"
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#if CUDA_VERSION >= 7000
2 changes: 0 additions & 2 deletions torch/lib/THC/THCStorageCopy.c
@@ -1,7 +1,5 @@
#include "THCStorageCopy.h"
#include "THCGeneral.h"

#include "THCHalf.h"
#include "THCTensorCopy.h"

#include "generic/THCStorageCopy.c"
4 changes: 0 additions & 4 deletions torch/lib/THC/THCTensorCopy.c
@@ -1,9 +1,5 @@
#include "THCTensorCopy.h"
#include "THCGeneral.h"
#include "THCTensor.h"
#include "THCCachingHostAllocator.h"

#include "THCHalf.h"

#include "generic/THCTensorCopy.c"
#include "THCGenerateAllTypes.h"
1 change: 1 addition & 0 deletions torch/lib/THC/THCTensorMasked.cuh
@@ -5,6 +5,7 @@
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCReduce.cuh"
#include "THCThrustAllocator.cuh"

#include <thrust/device_ptr.h>
#include <thrust/scan.h>
1 change: 1 addition & 0 deletions torch/lib/THC/THCTensorMath.cu
@@ -4,6 +4,7 @@
#include "THCApply.cuh"
#include "THCNumerics.cuh"
#include "THCTensorMath.cuh"
#include "THCThrustAllocator.cuh"

#include <thrust/copy.h>
#include <thrust/count.h>
76 changes: 76 additions & 0 deletions torch/lib/THC/THCTensorMath.cuh
@@ -23,4 +23,80 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
}
}

#define CAT_ARRAY_BATCH_SIZE 1024
#define CAT_ARRAY_MAX_INPUT_DIMS 4

// Similar to any other IndexToOffset calculation for copying along a given dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
static inline __device__ IndexType compute(
const IndexType outputSize[Dims],
const IndexType outputStride[Dims],
const IndexType dimSize,
const unsigned int concatDim,
IndexType linearIndex) {
IndexType offset = 0;

#pragma unroll
for (int i = Dims - 1; i >= 1; --i) {
IndexType curDimSize = i == concatDim ? dimSize : outputSize[i];
IndexType nextDimIndex = linearIndex / curDimSize;
IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
IndexType curDimOffset = curDimIndex * outputStride[i];
offset += curDimOffset;
linearIndex = nextDimIndex;
}

return offset + linearIndex * outputStride[0];
}
};

template <typename T, typename IndexType>
struct CatArrInputTensor {
T* input;
IndexType offset;
IndexType dimSize;
IndexType nElements;
};

template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
IndexType outputSize[MaxDims];
IndexType outputStride[MaxDims];
};

/**
* Kernel used to concatenate gridDim.y tensors into an output tensor. Uses a grid-stride loop based on
* blockIdx.x and threadIdx.x for each input to copy each element from each input tensor into the output.
*
* output: base pointer to the storage associated with the output tensor
* inputs: GPU-allocated array of input metadata for each input to concatenate in the kernel
* os: the size/stride vectors for the output tensor
* concatDim: dimension along which we are concatenating
* dimStride: the stride of the output tensor at the concatDim
*
* The most important assumption made is that the input tensors are contiguous.
*/
template <typename T, typename IndexType, int Dims>
__global__ void CatArrayBatchedCopy(
T* output,
CatArrInputTensor<T, IndexType>* inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {
T* data = inputs[blockIdx.y].input;
IndexType offset = inputs[blockIdx.y].offset;
IndexType dimSize = inputs[blockIdx.y].dimSize;
IndexType nElements = inputs[blockIdx.y].nElements;
IndexType dataOffset = offset * dimStride;

for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < nElements;
linearIndex += gridDim.x * blockDim.x) {
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, linearIndex);
output[dataOffset + elementOffset] = data[linearIndex];
}
}

#endif
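To make the index arithmetic concrete, here is a host-side rendition of the Dims == 3 case with a worked (hypothetical) example: concatenating along dim 1 of a [2, 5, 3] output with strides {15, 3, 1}, an input slice with dimSize 2 maps its linear index 7 (input coordinates [1, 0, 1]) to offset 1*15 + 0*3 + 1*1 = 16, to which the kernel adds dataOffset = offset * dimStride for the slice's starting slot along the concat dimension:

#include <assert.h>
#include <stddef.h>

/* Host-side copy of the CatArrIndexToOffset loop for Dims == 3. */
static size_t catOffset(const size_t size[3], const size_t stride[3],
                        size_t dimSize, int concatDim, size_t linearIndex) {
  size_t offset = 0;
  for (int i = 2; i >= 1; --i) {
    /* Along concatDim the input's extent (dimSize), not the output's,
       drives the decomposition of the linear index. */
    size_t curDimSize = (i == concatDim) ? dimSize : (size_t)size[i];
    size_t nextDimIndex = linearIndex / curDimSize;
    size_t curDimIndex = linearIndex - curDimSize * nextDimIndex;
    offset += curDimIndex * stride[i];
    linearIndex = nextDimIndex;
  }
  return offset + linearIndex * stride[0];
}

int main(void) {
  size_t size[3] = {2, 5, 3};
  size_t stride[3] = {15, 3, 1};
  assert(catOffset(size, stride, 2, 1, 7) == 16); /* element [1,0,1] */
  return 0;
}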
1 change: 1 addition & 0 deletions torch/lib/THC/THCTensorMathReduce.cuh
@@ -6,6 +6,7 @@
#include "THCNumerics.cuh"
#include "THCReduce.cuh"
#include "THCReduceAll.cuh"
#include "THCThrustAllocator.cuh"
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/transform_reduce.h>
1 change: 1 addition & 0 deletions torch/lib/THC/THCTensorSort.cuh
@@ -6,6 +6,7 @@
#include "THCTensorCopy.h"
#include "THCTensorTypeUtils.cuh"

#include "THCThrustAllocator.cuh"
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#if CUDA_VERSION >= 7000