cuda_kernels.cu (forked from horovod/horovod)

// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "cuda_kernels.h"
#include <stdexcept>
#include <cuda_fp16.h>
namespace horovod {
namespace common {

template<typename T, int blocks_per_copy>
__device__ void batched_memcpy_d(size_t idx, const void* in, void* out, size_t size) {

  const T* input = reinterpret_cast<const T *>(in);
  T* output = reinterpret_cast<T *>(out);
  const size_t num_elements = size / sizeof(T);

  for (size_t i = idx; i < num_elements; i += blockDim.x * blocks_per_copy) {
    output[i] = input[i];
  }

  // Deal with any remaining bytes
  size_t remainder = size % sizeof(T);
  if (remainder > 0 && idx < remainder) {
    const unsigned char* input_r = reinterpret_cast<const unsigned char *>(input + num_elements);
    unsigned char* output_r = reinterpret_cast<unsigned char *>(output + num_elements);
    output_r[idx] = input_r[idx];
  }
}
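
// Worked example (illustrative): copying size = 37 bytes with T = ulonglong2
// (sizeof(T) == 16) yields num_elements = 2 vector-wide copies plus a 5-byte
// remainder, which the threads with idx < 5 copy one byte at a time.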

template<int blocks_per_copy>
__global__ void batched_memcpy_k(BatchedD2DParams params) {

  const size_t idx = blockDim.x * (blockIdx.x % blocks_per_copy) + threadIdx.x;

  const size_t size = params.sizes[blockIdx.x / blocks_per_copy];
  const void* input = params.in[blockIdx.x / blocks_per_copy];
  void* output = params.out[blockIdx.x / blocks_per_copy];

  // Check alignment relative to 16 bytes
  size_t align_in = reinterpret_cast<size_t>(input) % BATCHED_D2D_PADDING;
  size_t align_out = reinterpret_cast<size_t>(output) % BATCHED_D2D_PADDING;

  // Select load/store size based on the misaligned buffer
  size_t align = (align_out == 0) ? align_in : align_out;
  if (align_in && align_out) {
    // If both are misaligned, use unsigned char (this should not occur,
    // as fusion buffer locations should be aligned by applying BATCHED_D2D_PADDING
    // during construction.)
    align = 1;
  }

  if (align % 16 == 0) {
    batched_memcpy_d<ulonglong2, blocks_per_copy>(idx, input, output, size);
  } else if (align % 8 == 0) {
    batched_memcpy_d<unsigned long long, blocks_per_copy>(idx, input, output, size);
  } else if (align % 4 == 0) {
    batched_memcpy_d<unsigned int, blocks_per_copy>(idx, input, output, size);
  } else if (align % 2 == 0) {
    batched_memcpy_d<unsigned short, blocks_per_copy>(idx, input, output, size);
  } else {
    batched_memcpy_d<unsigned char, blocks_per_copy>(idx, input, output, size);
  }
}
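
// Indexing sketch (illustrative): with blocks_per_copy = 8, blockIdx.x = 19
// serves copy 19 / 8 = 2 as its intra-copy block 19 % 8 = 3, and its threads
// walk that copy with a stride of blockDim.x * blocks_per_copy.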

#define NTHREADS_D2D_KERNEL 1024
#define BLOCKS_PER_COPY_D2D_KERNEL 8

void BatchedD2DMemcpyCudaImpl(BatchedD2DParams& params, int num_copies, cudaStream_t stream)
{
  batched_memcpy_k<BLOCKS_PER_COPY_D2D_KERNEL><<<num_copies * BLOCKS_PER_COPY_D2D_KERNEL,
                                                 NTHREADS_D2D_KERNEL, 0, stream>>>(params);
}
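
// Usage sketch (illustrative; `d_dst`, `d_src`, `copy_bytes`, `num_copies`, and
// `stream` are hypothetical caller-side values, with all pointers referring to
// device memory and `num_copies` within the capacity of BatchedD2DParams):
//
//   BatchedD2DParams params;
//   for (int i = 0; i < num_copies; ++i) {
//     params.out[i] = d_dst[i];
//     params.in[i] = d_src[i];
//     params.sizes[i] = copy_bytes[i];
//   }
//   BatchedD2DMemcpyCudaImpl(params, num_copies, stream);
//   // The launch is asynchronous; synchronize `stream` before freeing the
//   // source or destination buffers.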

template<typename T, typename TS>
__global__ void scale_buffer_k(const T* input, T* output, int64_t num_elements, const TS scale_factor) {

  const size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
    output[i] = scale_factor * input[i];
  }
}
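
// Note (illustrative): the grid-stride loop above starts each thread at its
// global index and advances by gridDim.x * blockDim.x, so the kernel remains
// correct for any num_elements, independent of the launch configuration.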

// Specialization for half2
__global__ void scale_buffer_half2_k(const __half* input, __half* output, int64_t num_elements, const __half scale_factor) {

  const size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ > 530
  const __half2* input_h2 = reinterpret_cast<const __half2 *>(input);
  __half2* output_h2 = reinterpret_cast<__half2 *>(output);
  const __half2 scale_factor_h2 = __halves2half2(scale_factor, scale_factor);

  for (size_t i = idx; i < num_elements / 2; i += gridDim.x * blockDim.x) {
    output_h2[i] = __hmul2(scale_factor_h2, input_h2[i]);
  }

  // Deal with last element if num_elements is odd
  if (idx == 0 && num_elements % 2) {
    output[num_elements - 1] = __hmul(scale_factor, input[num_elements - 1]);
  }
#else
  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
    output[i] = __float2half(__half2float(scale_factor) * __half2float(input[i]));
  }
#endif
}
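
// Example (illustrative): for num_elements = 7 the vectorized path scales
// elements 0..5 as three __half2 pairs, and thread 0 handles the odd trailing
// element 6 with a scalar __hmul.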

// Specialization for architectures without __half compute
template<>
__global__ void scale_buffer_k(const __half* input, __half* output, int64_t num_elements, const __half scale_factor) {

  const size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ > 530
  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
    output[i] = scale_factor * input[i];
  }
#else
  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
    output[i] = __float2half(__half2float(scale_factor) * __half2float(input[i]));
  }
#endif
}
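
// Note (illustrative): when __CUDA_ARCH__ is 530 or lower, the fallback path
// promotes each __half to float, multiplies, and converts back, trading some
// throughput for portability.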

#define NTHREADS_SCALE_BUFFER_KERNEL 512

void ScaleBufferCudaImpl(const void* fused_input_data, void* buffer_data, const int64_t num_elements, double scale_factor,
                         DataType dtype, cudaStream_t stream) {
  const int64_t blocks = (num_elements + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL;
  const int threads = NTHREADS_SCALE_BUFFER_KERNEL;
  switch (dtype) {
    case HOROVOD_UINT8:
      scale_buffer_k<<<blocks, threads, 0, stream>>>((const uint8_t*) fused_input_data, (uint8_t*) buffer_data,
                                                     num_elements, scale_factor);
      break;
    case HOROVOD_INT8:
      scale_buffer_k<<<blocks, threads, 0, stream>>>((const int8_t*) fused_input_data, (int8_t*) buffer_data,
                                                     num_elements, scale_factor);
      break;
    case HOROVOD_INT32:
      scale_buffer_k<<<blocks, threads, 0, stream>>>((const int32_t*) fused_input_data, (int32_t*) buffer_data,
                                                     num_elements, scale_factor);
      break;
    case HOROVOD_INT64:
      scale_buffer_k<<<blocks, threads, 0, stream>>>((const int64_t*) fused_input_data, (int64_t*) buffer_data,
                                                     num_elements, scale_factor);
      break;
    case HOROVOD_FLOAT16:
    {
      __half scale_factor_half = __float2half((float) scale_factor);
      if ((size_t) fused_input_data % 4 == 0 && (size_t) buffer_data % 4 == 0) {
        // If alignment allows, use half2 specialized kernel
        int64_t num_elements_h2 = (num_elements + 1) / 2;
        int64_t blocks_h2 = (num_elements_h2 + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL;
        scale_buffer_half2_k<<<blocks_h2, threads, 0, stream>>>((const __half*) fused_input_data, (__half*) buffer_data,
                                                                num_elements, scale_factor_half);
      } else {
        scale_buffer_k<<<blocks, threads, 0, stream>>>((const __half*) fused_input_data, (__half*) buffer_data,
                                                       num_elements, scale_factor_half);
      }
      break;
    }
    case HOROVOD_FLOAT32:
      scale_buffer_k<<<blocks, threads, 0, stream>>>((const float*) fused_input_data, (float*) buffer_data,
                                                     num_elements, (float) scale_factor);
      break;
    case HOROVOD_FLOAT64:
      scale_buffer_k<<<blocks, threads, 0, stream>>>((const double*) fused_input_data, (double*) buffer_data,
                                                     num_elements, scale_factor);
      break;
    default:
      throw std::logic_error("Type " + DataType_Name(dtype) +
                             " not supported by ScaleBufferCudaImpl.");
  }
}
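
// Usage sketch (illustrative; `d_in`, `d_out`, `n`, `world_size`, and `stream`
// are hypothetical caller-side values, with both pointers referring to device
// buffers of `n` float elements):
//
//   ScaleBufferCudaImpl(d_in, d_out, n, 1.0 / world_size, HOROVOD_FLOAT32, stream);
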
} // namespace common
} // namespace horovod