Skip to content

Commit

Permalink
compiles
Browse files Browse the repository at this point in the history
  • Loading branch information
mcarilli committed Jan 5, 2021
1 parent 44c17b2 commit 180e173
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions aten/src/ATen/native/cuda/Dropout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,

accscalar_t pinv = accscalar_t(1)/p;

// Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
// in the vec=2 and vec=4 cases.
bool gridxvec_loop_state = 0;

float4 rand;

// Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time
for (IndexType linearIndex = idx * VEC;
linearIndex < totalElements;
Expand All @@ -69,12 +75,20 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
// Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4)
// sets of rand.
float4 rand = curand_uniform4(&state);
if ((VEC == 4) || (gridxvec_loop_state == 0)) {
rand = curand_uniform4(&state);
} else {
rand.x = rand.z;
rand.y = rand.w;
gridxvec_loop_state ^= 1;
}

rand.x = rand.x < p;
rand.y = rand.y < p;
rand.z = rand.z < p;
rand.w = rand.w < p;
if (VEC == 4) {
rand.z = rand.z < p;
rand.w = rand.w < p;
}

// Note: We explicitly check for is_contiguous() before launching the vectorized kernel
// and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other)
Expand Down

0 comments on commit 180e173

Please sign in to comment.