Skip to content

Commit

Permalink
make cunn compile with msvc && fix compilation failure for linux/mac os
Browse files Browse the repository at this point in the history
  • Loading branch information
BTNC committed Oct 14, 2016
1 parent 073ba88 commit 2bd89f4
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 5 deletions.
8 changes: 8 additions & 0 deletions lib/THC/THCDeviceTensor-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
THCDeviceTensor<T, Dim, IndexT, PtrTraits>::
#ifdef _MSC_VER
THCDeviceTensor(DataPtrType data, const IndexT (&sizes)[Dim])
#else
THCDeviceTensor(DataPtrType data, const IndexT sizes[Dim])
#endif
: data_(data) {
thc_static_assert(Dim > 0);

Expand All @@ -46,7 +50,11 @@ template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
__host__ __device__
THCDeviceTensor<T, Dim, IndexT, PtrTraits>::THCDeviceTensor(
#ifdef _MSC_VER
DataPtrType data, const IndexT (&sizes)[Dim], const IndexT (&strides)[Dim])
#else
DataPtrType data, const IndexT sizes[Dim], const IndexT strides[Dim])
#endif
: data_(data) {
thc_static_assert(Dim > 0);

Expand Down
9 changes: 9 additions & 0 deletions lib/THC/THCDeviceTensor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,21 @@ class THCDeviceTensor {

/// Constructor that calculates strides with no padding
__host__ __device__ THCDeviceTensor(DataPtrType data,
#ifdef _MSC_VER
const IndexT (&sizes)[Dim]);
#else
const IndexT sizes[Dim]);
#endif

/// Constructor that takes arbitrary size/stride arrays
__host__ __device__ THCDeviceTensor(DataPtrType data,
#ifdef _MSC_VER
const IndexT (&sizes)[Dim],
const IndexT (&strides)[Dim]);
#else
const IndexT sizes[Dim],
const IndexT strides[Dim]);
#endif

/// Returns true if the two tensors are of the same dimensionality,
/// size and stride.
Expand Down
3 changes: 3 additions & 0 deletions lib/THC/THCGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
#ifdef _WIN32
# ifdef THC_EXPORTS
# define THC_API THC_EXTERNC __declspec(dllexport)
# define THC_CLASS __declspec(dllexport)
# else
# define THC_API THC_EXTERNC __declspec(dllimport)
# define THC_CLASS __declspec(dllimport)
# endif
#else
# define THC_API THC_EXTERNC
# define THC_CLASS
#endif

#ifndef THAssert
Expand Down
4 changes: 2 additions & 2 deletions lib/THC/THCHalf.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ THC_API half THC_float2half(float a);
THC_API float THC_half2float(half a);

/* Check for native fp16 support on the current device (CC 5.3+) */
THC_EXTERNC int THC_nativeHalfInstructions(THCState *state);
THC_API int THC_nativeHalfInstructions(THCState *state);

/* Check for performant native fp16 support on the current device */
THC_EXTERNC int THC_fastHalfInstructions(THCState *state);
THC_API int THC_fastHalfInstructions(THCState *state);

#endif /* CUDA_HALF_TENSOR */

Expand Down
2 changes: 1 addition & 1 deletion lib/THC/THCTensorRandom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ __global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *r
}
}

#define NUM_BLOCKS min(THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), (ptrdiff_t) MAX_NUM_BLOCKS)
#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
THC_API void THCudaTensor_uniform(THCState* state, THCudaTensor *self_, double a, double b)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
Expand Down
2 changes: 1 addition & 1 deletion lib/THC/THCTensorTypeUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct TensorUtils {

#define TENSOR_UTILS(TENSOR_TYPE, DATA_TYPE, ACC_DATA_TYPE) \
template <> \
struct TensorUtils<TENSOR_TYPE> { \
struct THC_CLASS TensorUtils<TENSOR_TYPE> { \
typedef DATA_TYPE DataType; \
typedef ACC_DATA_TYPE AccDataType; \
\
Expand Down
2 changes: 1 addition & 1 deletion torch/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#endif

#ifdef _WIN32
# ifdef torch_EXPORTS
# ifdef cutorch_EXPORTS
# define TORCH_API TORCH_EXTERNC __declspec(dllexport)
# else
# define TORCH_API TORCH_EXTERNC __declspec(dllimport)
Expand Down

0 comments on commit 2bd89f4

Please sign in to comment.