#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/util/irange.h>
#include <vector>
namespace at::native {
namespace {
static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;
// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch arguments
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
72,
60};
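// Returns true iff `p` can be accessed as a kILP-wide vector of T, i.e. its
// address is a multiple of kILP * sizeof(T).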
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
}
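// Moves kILP contiguous elements of T in one vectorized load/store by
// reinterpreting both pointers as aligned_vector<T, kILP>; callers are
// expected to verify `is_aligned` on both pointers first.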
template <typename T>
__device__ __forceinline__ void load_store(
T* dst,
T* src,
int64_t dst_offset,
int64_t src_offset) {
using LT = at::native::memory::aligned_vector<T, kILP>;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template <int n>
struct TensorListMetadata {
const void* addresses[n][depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
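// Rough budget check against the 4KB argument limit (assuming 8-byte
// pointers), e.g. for depth 1: 110 * 8 bytes of addresses + 110 * 8 bytes of
// numel_for_tensor + 320 bytes of block_to_tensor + 320 * 4 bytes of
// block_to_chunk + 4 bytes of start_tensor_this_launch ~= 3.3KB, which leaves
// room for the callable and any extra kernel arguments.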
template <typename scalar_vals_t, int n>
struct TensorListScalarListMetadata {
const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]];
};
// note(mkozuki): with `c10::complex<double>` scalars, `n` of 1 and 2 would
// exceed the 4KB limit on CUDA kernel argument size, hence the smaller tensor
// counts in the specializations below.
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 1> {
const void* addresses[1]
[depth_to_max_tensors_scalarlist_of_complex_double[0]];
int64_t
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
c10::complex<double>
scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
int block_to_chunk[depth_to_max_blocks[1 - 1]];
};
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 2> {
const void* addresses[2]
[depth_to_max_tensors_scalarlist_of_complex_double[1]];
int64_t
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
c10::complex<double>
scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
int block_to_chunk[depth_to_max_blocks[2 - 1]];
};
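// Same back-of-the-envelope check for these specializations (16-byte
// c10::complex<double> scalars, 8-byte pointers): the default 96 tensors at
// depth 1 would need 96 * 8 + 96 * 8 + 96 * 16 + 320 + 320 * 4 = 4672 bytes
// > 4KB, whereas 72 tensors need 3904 bytes; likewise 60 tensors at depth 2
// need 2 * 60 * 8 + 60 * 8 + 60 * 16 + 320 + 320 * 4 = 4000 bytes.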
// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`,
// each element of which is a single-element `at::Tensor` counting the number
// of `step`s taken so far.
template <int n>
struct FusedOptimizerTensorListMetadata {
const void* addresses[n][depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void multi_tensor_apply_kernel(
T tensorListMeta,
U callable,
ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(kChunkSize, tensorListMeta, args...);
}
} // namespace
// multi_tensor_apply enables horizontal fusion across lists of tensors.
// For example, whereas you once had a for-loop of a + b = c, where a, b,
// and c are individual tensors in lists as, bs, and cs, you can now with
// fewer kernel launches compute as + bs = cs.
//
// You can also imagine bs to be a scalar list vs a tensor list.
//
// The function below takes in tensor lists, scalars, and a callable and
// chunks up the computation to launch as few kernels as possible by iterating
// through every "chunk" in every tensor (thus the nested for loops). In the
// simplest case, everything gets bundled into just one kernel launch, but
// due to blocksize constraints, we may need to launch multiple kernels.
// Each kernel launch is defined by one tensorListMeta construct, which we
// use to track and reset the necessary metadata for each launch.
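//
// Illustrative only (not part of this header, and a simplification of the
// real functors in e.g. ForeachFunctors.cuh): a callable handed to the second
// overload below (the one without a scalar list) might look roughly like this
// hypothetical scale-into-output functor, which resolves its block's tensor
// and chunk from the metadata and then strides over that chunk:
//
//   template <typename scalar_t, int depth>
//   struct ExampleScaleFunctor {
//     __device__ void operator()(
//         int chunk_size,
//         TensorListMetadata<depth>& tl,
//         scalar_t alpha) {
//       const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
//       const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
//       const int64_t offset = static_cast<int64_t>(chunk_idx) * chunk_size;
//       const auto n = tl.numel_for_tensor[tensor_loc] - offset;
//       const auto* in = (const scalar_t*)tl.addresses[0][tensor_loc] + offset;
//       auto* out = (scalar_t*)tl.addresses[1][tensor_loc] + offset;
//       for (int64_t i = threadIdx.x; i < n && i < chunk_size;
//            i += blockDim.x) {
//         out[i] = in[i] * alpha;
//       }
//     }
//   };
//
//   // e.g.:
//   // multi_tensor_apply<2>(tensor_lists, ExampleScaleFunctor<scalar_t, 2>(), alpha);
//
// (Real functors additionally vectorize via is_aligned/load_store.)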
template <int depth, typename scalar_T, typename T, typename... ArgTypes>
void multi_tensor_apply(
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::ArrayRef<Scalar> scalars,
T callable,
ArgTypes... args) {
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
using scalar_vals_t = typename T::opmath_t;
TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (size_t t = 0; t < n_tensors; t++) {
// short-circuit to avoid adding empty tensors to tensorListMeta
if (tensor_lists[0][t].numel() == 0) {
continue;
}
tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
tensorListMeta.numel_for_tensor[loc_tensor_info] =
tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][loc_tensor_info] =
tensor_lists[d][t].const_data_ptr();
}
loc_tensor_info++;
// now we enter [chunking territory].
// we will launch a kernel when EITHER the blocks get filled up OR
// the tensors get filled up. There will always be at least one block
// per tensor since the zero-sized ones will not enter the loop, so
    // the nested for-loop within represents iterating through the chunks
// of a single tensor.
const auto numel = tensor_lists[0][t].numel();
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
for (auto chunk = 0; chunk < chunks; chunk++) {
tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tensorListMeta.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
// a tensor is not considered full unless all its chunks have been
// processed
const bool tensors_full =
(loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
(loc_block_info == depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(
tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Reset.
loc_block_info = 0;
// all chunks have already been handled in the kernel
if (chunk == chunks - 1) {
loc_tensor_info = 0;
} else { // blocks were full and tensor chunks remain
tensorListMeta.numel_for_tensor[0] =
tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
tensorListMeta.scalar_vals[0] =
tensorListMeta.scalar_vals[loc_tensor_info - 1];
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][0] =
tensorListMeta.addresses[d][loc_tensor_info - 1];
}
loc_tensor_info = 1;
}
}
}
}
  // note: [finishing what we started]
  // if there's remaining work to be done but the tensors/blocks aren't full
  // and we've reached the end, submit the kernel to do the work!
if (loc_block_info != 0) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args) {
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
TensorListMetadata<depth> tensorListMeta;
tensorListMeta.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (size_t t = 0; t < n_tensors; t++) {
// short-circuit to avoid adding empty tensors to tensorListMeta
if (tensor_lists[0][t].numel() == 0) {
continue;
}
tensorListMeta.numel_for_tensor[loc_tensor_info] =
tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][loc_tensor_info] =
tensor_lists[d][t].const_data_ptr();
}
loc_tensor_info++;
// see note: [chunking territory].
const auto numel = tensor_lists[0][t].numel();
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
for (auto chunk = 0; chunk < chunks; chunk++) {
tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tensorListMeta.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
const bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
(loc_block_info == depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(
tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Reset.
loc_block_info = 0;
if (chunk == chunks - 1) {
loc_tensor_info = 0;
tensorListMeta.start_tensor_this_launch = t + 1;
} else {
tensorListMeta.numel_for_tensor[0] =
tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][0] =
tensorListMeta.addresses[d][loc_tensor_info - 1];
}
loc_tensor_info = 1;
tensorListMeta.start_tensor_this_launch = t;
}
}
}
}
// see note: [finishing what we started]
if (loc_block_info != 0) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_for_fused_optimizer(
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::TensorList state_steps,
T callable,
ArgTypes... args) {
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth");
const auto num_tensors = tensor_lists[0].size();
FusedOptimizerTensorListMetadata<depth> tensorListMeta;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (const auto& tensor_index : c10::irange(num_tensors)) {
// short-circuit to avoid adding empty tensors to tensorListMeta
if (tensor_lists[0][tensor_index].numel() == 0) {
continue;
}
tensorListMeta.state_steps_addresses[loc_tensor_info] =
state_steps[tensor_index].const_data_ptr();
tensorListMeta.numel_for_tensor[loc_tensor_info] =
tensor_lists[0][tensor_index].numel();
for (const auto& d : c10::irange(depth)) {
tensorListMeta.addresses[d][loc_tensor_info] =
tensor_lists[d][tensor_index].const_data_ptr();
}
loc_tensor_info++;
// see above note: [chunking territory]
const auto numel = tensor_lists[0][tensor_index].numel();
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
TORCH_CHECK(chunks > -1);
for (const auto& chunk : c10::irange(chunks)) {
tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tensorListMeta.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
const auto tensor_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
if (tensor_full || blocks_full) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(
tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Reset.
loc_block_info = 0;
if (chunk == chunks - 1) {
loc_tensor_info = 0;
} else {
tensorListMeta.numel_for_tensor[0] =
tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
tensorListMeta.state_steps_addresses[0] =
tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
for (const auto& d : c10::irange(depth)) {
tensorListMeta.addresses[d][0] =
tensorListMeta.addresses[d][loc_tensor_info - 1];
}
loc_tensor_info = 1;
}
}
}
}
  // see above note: [finishing what we started]
if (loc_block_info != 0) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
} // namespace at::native