-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
CUDABlas.cpp
814 lines (757 loc) · 29.9 KB
/
CUDABlas.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
/*
Provides the implementations of CUDA BLAS function templates.
*/
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
// Size-argument validation for the BLAS wrappers in this file. The BLAS entry
// points called below take `int` sizes, so each int64_t argument must fit in
// [1, INT_MAX] (CUDABLAS_POSINT_CHECK) or [0, INT_MAX]
// (CUDABLAS_NONNEGINT_CHECK) before it is narrowed.
//   FD - name of the calling routine; stringized into the error message.
//   X  - the value to check; stringized into the message and also printed.
// X is parenthesized so that a compound argument (e.g. `a & b`) is evaluated
// as a whole before being compared.
#define CUDABLAS_POSINT_CHECK(FD, X)                  \
  TORCH_CHECK(                                        \
      ((X) > 0 && (X) <= INT_MAX),                    \
      "at::cuda::blas::" #FD " argument " #X          \
      " must be positive and less than or equal to ", \
      INT_MAX,                                        \
      " but got ",                                    \
      (X))
#define CUDABLAS_NONNEGINT_CHECK(FD, X)                   \
  TORCH_CHECK(                                            \
      ((X) >= 0 && (X) <= INT_MAX),                       \
      "at::cuda::blas::" #FD " argument " #X              \
      " must be non-negative and less than or equal to ", \
      INT_MAX,                                            \
      " but got ",                                        \
      (X))
namespace {

// Translates a BLAS transpose character into a cuBLAS operation:
//   'n'/'N' -> CUBLAS_OP_N, 't'/'T' -> CUBLAS_OP_T, 'c'/'C' -> CUBLAS_OP_C.
// Any other character raises via AT_ERROR.
static cublasOperation_t _cublasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return CUBLAS_OP_N;
    case 't':
    case 'T':
      return CUBLAS_OP_T;
    case 'c':
    case 'C':
      return CUBLAS_OP_C;
  }
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

// Normalizes the leading dimension of the m x n (column-major) matrix A for a
// Level 2 call. When A has at most one column, lda never steps between
// columns, so it is replaced by max(m, 1).
//
// Note: leading dimensions generally are checked that they are > 0
// and at least as big the result requires (even if the value won't
// be used).
// Q: Why does Level3 check trans but this doesn't?
// A: In level 2, the sizes (m, n) specify the size of A
// (independent of trans value). In level 3. the sizes (m, n, k)
// specify the sizes of op(A), op(B) where op depend on trans
// values.
static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  if (n <= 1)
    *lda = std::max<int64_t>(m, 1);
}

// Normalizes lda/ldb/ldc for a Level 3 call computing the m x n matrix
// C = op(A) * op(B), where op(A) is m x k and op(B) is k x n, all
// column-major. When a stored matrix has at most one column, its leading
// dimension is never used for addressing. It is then replaced by
// max(rows, 1) so the "> 0 and large enough" validation passes.
//
// Both 't' and 'c' (conjugate transpose) mean the operand is stored
// transposed, so every op other than 'n'/'N' is treated as transposed here.
// Treating 'c' as non-transposed would overwrite a leading dimension that IS
// used for addressing (e.g. k == 1 with A stored as 1 x m), making cuBLAS read
// the wrong elements. Callers in this file validate transa/transb with
// _cublasOpFromChar first, so no other characters reach this function from
// them.
static void _cublasAdjustLdLevel3(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t* lda,
    int64_t* ldb,
    int64_t* ldc) {
  const bool transa_ = (transa != 'n') && (transa != 'N');
  const bool transb_ = (transb != 'n') && (transb != 'N');
  // Note: leading dimensions generally are checked that they are > 0
  // and at least as big the result requires (even if the value won't
  // be used).
  // C is stored m x n.
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);
  // A is stored k x m when transposed, m x k otherwise.
  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }
  // B is stored n x k when transposed, k x n otherwise.
  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}

} // anonymous namespace
namespace at {
namespace cuda {
namespace blas {
// Returns the enumerator name of a cuBLAS status code (e.g.
// "CUBLAS_STATUS_ALLOC_FAILED") for use in error messages, or "<unknown>" for
// codes not listed below.
const char* _cublasGetErrorEnum(cublasStatus_t error) {
  struct StatusName {
    cublasStatus_t status;
    const char* name;
  };
  // Scanned front to back; the first matching entry wins, so the lookup stays
  // well-defined even if two enumerators share a value.
  // NOTE(review): that is presumably why this is not a `switch` (duplicate
  // case labels would not compile, e.g. under HIP translation) — confirm
  // before converting.
  static const StatusName kStatusNames[] = {
      {CUBLAS_STATUS_SUCCESS, "CUBLAS_STATUS_SUCCESS"},
      {CUBLAS_STATUS_NOT_INITIALIZED, "CUBLAS_STATUS_NOT_INITIALIZED"},
      {CUBLAS_STATUS_ALLOC_FAILED, "CUBLAS_STATUS_ALLOC_FAILED"},
      {CUBLAS_STATUS_INVALID_VALUE, "CUBLAS_STATUS_INVALID_VALUE"},
      {CUBLAS_STATUS_ARCH_MISMATCH, "CUBLAS_STATUS_ARCH_MISMATCH"},
      {CUBLAS_STATUS_MAPPING_ERROR, "CUBLAS_STATUS_MAPPING_ERROR"},
      {CUBLAS_STATUS_EXECUTION_FAILED, "CUBLAS_STATUS_EXECUTION_FAILED"},
      {CUBLAS_STATUS_INTERNAL_ERROR, "CUBLAS_STATUS_INTERNAL_ERROR"},
      {CUBLAS_STATUS_NOT_SUPPORTED, "CUBLAS_STATUS_NOT_SUPPORTED"},
#ifdef CUBLAS_STATUS_LICENSE_ERROR
      {CUBLAS_STATUS_LICENSE_ERROR, "CUBLAS_STATUS_LICENSE_ERROR"},
#endif
  };
  for (const StatusName& entry : kStatusNames) {
    if (entry.status == error) {
      return entry.name;
    }
  }
  return "<unknown>";
}
/* LEVEL 3 BLAS FUNCTIONS */
#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
// Workaround for https://github.com/pytorch/pytorch/issues/45724
//
// Drop-in replacement for cublasGemmStridedBatchedEx (same arguments, except
// that batchCount is int64_t).
// - Compute capability major == 7: the batch is issued in chunks of at most
//   63 * 1024 problems. Each chunk's status goes through TORCH_CUDABLAS_CHECK,
//   and the last chunk's status is returned. With batchCount <= 0 no cuBLAS
//   call is made and CUBLAS_STATUS_SUCCESS is returned.
// - All other devices: a single direct call, returning its status.
//
// NOTE(review): the chunk offsets assume 2-byte elements. Every caller in this
// file passes CUDA_R_16F or CUDA_R_16BF data; confirm before using this with
// wider types.
cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const void *alpha,
    const void *A,
    cudaDataType Atype,
    int lda,
    long long int strideA,
    const void *B,
    cudaDataType Btype,
    int ldb,
    long long int strideB,
    const void *beta,
    void *C,
    cudaDataType Ctype,
    int ldc,
    long long int strideC,
    int64_t batchCount,
    cudaDataType computeType,
    cublasGemmAlgo_t algo)
{
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major != 7) {
    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
  }
  // Bytes per element; see the NOTE above.
  constexpr int64_t element_size = 2;
  constexpr int64_t split = 63 * 1024;
  // Initialized so an empty batch (loop body never runs) returns a defined
  // status instead of reading an uninitialized variable.
  cublasStatus_t result = CUBLAS_STATUS_SUCCESS;
  for (int64_t i = 0; i < batchCount; i += split) {
    const int64_t count = std::min<int64_t>(split, batchCount - i);
    result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
      static_cast<const char*>(A) + i * strideA * element_size, Atype, lda, strideA,
      static_cast<const char*>(B) + i * strideB * element_size, Btype, ldb, strideB,
      beta,
      static_cast<char*>(C) + i * strideC * element_size, Ctype, ldc, strideC,
      static_cast<int>(count), computeType, algo);
    TORCH_CUDABLAS_CHECK(result);
  }
  return result;
}
#endif
#endif
// Validates the size arguments of a gemm<Dtype> call: m, n, k must lie in
// [0, INT_MAX] and lda, ldb, ldc in [1, INT_MAX]. Expects locals named m, n,
// k, lda, ldb, ldc in scope. Dtype only labels the error message: the routine
// name is stringized (e.g. "at::cuda::blas::gemm<float> argument m ...").
// In the specializations below it is typically invoked right after
// _cublasAdjustLdLevel3 has normalized degenerate leading dimensions.
#define GEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, n); \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, k); \
CUDABLAS_POSINT_CHECK(gemm<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldb); \
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \
} while (0)
// Same checks as GEMM_CHECK_ARGVALUES, labelled bgemm<Dtype>, plus
// num_batches in [0, INT_MAX]. Expects a local named num_batches in scope.
#define BGEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)
// Strided-batched double-precision GEMM via cublasDgemmStridedBatched: for
// each of num_batches problems, C_i = alpha * op(A_i) * op(B_i) + beta * C_i,
// with consecutive A/B/C matrices stridea/strideb/stridec elements apart.
// NOTE(review): TORCH_CUDABLAS_CHECK presumably stringizes its argument into
// the error message, so edits inside that call change runtime message text.
template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}
// Strided-batched single-precision GEMM via cublasSgemmStridedBatched: for
// each of num_batches problems, C_i = alpha * op(A_i) * op(B_i) + beta * C_i,
// with consecutive A/B/C matrices stridea/strideb/stridec elements apart.
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}
// Strided-batched complex<double> GEMM via cublasZgemmStridedBatched: for
// each of num_batches problems, C_i = alpha * op(A_i) * op(B_i) + beta * C_i.
// NOTE(review): the reinterpret_casts rely on c10::complex<double> being
// layout-compatible with cuDoubleComplex (two doubles, real then imaginary).
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}
// Strided-batched complex<float> GEMM via cublasCgemmStridedBatched: for
// each of num_batches problems, C_i = alpha * op(A_i) * op(B_i) + beta * C_i.
// NOTE(review): the reinterpret_casts rely on c10::complex<float> being
// layout-compatible with cuComplex (two floats, real then imaginary).
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}
// Strided-batched half-precision GEMM: fp16 A/B/C, with alpha/beta converted
// to float and an fp32 compute type on both backends.
// - ROCm: one rocblas_gemm_strided_batched_ex call; c is passed as both the
//   input C and the output D operand.
// - CUDA, compute capability >= 5: one cublasGemmStridedBatchedExFix call.
// - CUDA, older devices: one gemm<at::Half> call per batch.
// On CUDA < 11 the handle is switched to CUBLAS_TENSOR_OP_MATH around the
// call(s) and reset to CUBLAS_DEFAULT_MATH afterwards.
// NOTE(review): if a call fails and TORCH_CUDABLAS_CHECK raises, the reset is
// skipped and the handle keeps CUBLAS_TENSOR_OP_MATH — confirm this is
// acceptable for CUDA < 11 builds.
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(at::Half);
// Scalars are passed as float to match the fp32 compute type.
float falpha = alpha;
float fbeta = beta;
#ifdef __HIP_PLATFORM_HCC__
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0));
#else
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
handle, opa, opb, m, n, k,
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
c, CUDA_R_16F, ldc, stridec,
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// Pre-Maxwell fallback: one non-batched gemm per problem, offsetting each
// operand by its batch stride (in elements).
for (int64_t i = 0; i < num_batches; ++i) {
at::cuda::blas::gemm<at::Half>(
transa, transb,
m, n, k,
alpha, (a + i * stridea), lda,
(b + i * strideb), ldb, beta,
(c + i * stridec), ldc);
}
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// Restore the default math mode that was replaced above for the FP16 call(s);
// only needed on CUDA versions prior to 11.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // __HIP_PLATFORM_HCC__
}
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Strided-batched bfloat16 GEMM: bf16 A/B/C, with alpha/beta converted to
// float and an fp32 compute type. On CUDA it requires compute capability >= 8.
// On ROCm, c is passed as both the input C and the output D operand.
//
// Leading dimensions are normalized by _cublasAdjustLdLevel3 *before*
// BGEMM_CHECK_ARGVALUES validates them, as in every other gemm/bgemm
// specialization in this file. Validating first would reject degenerate ld
// values (e.g. 0 for a single-column operand) that the adjustment exists to
// repair.
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Scalars are passed as float to match the fp32 compute type.
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(at::BFloat16);
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
      opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
      b, CUDA_R_16BF, (int)ldb, strideb,
      (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
      (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#elif defined(__HIP_PLATFORM_HCC__)
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
      b, rocblas_datatype_bf16_r, (int)ldb, strideb,
      (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
      c, rocblas_datatype_bf16_r, (int)ldc, stridec,
      (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
      0, 0, NULL, NULL));
#else
  TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif // __HIP_PLATFORM_HCC__
// Double-precision GEMM via cublasDgemm: C = alpha * op(A) * op(B) + beta * C,
// where op(A) is m x k, op(B) is k x n and C is m x n (column-major).
template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(cublasDgemm(
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}
// Single-precision GEMM via cublasSgemm: C = alpha * op(A) * op(B) + beta * C,
// where op(A) is m x k, op(B) is k x n and C is m x n (column-major).
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(cublasSgemm(
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
// complex<double> GEMM via cublasZgemm: C = alpha * op(A) * op(B) + beta * C.
// Compiled for CUDA, and for ROCm only when HIP_VERSION >= 210.
// NOTE(review): the reinterpret_casts rely on c10::complex<double> being
// layout-compatible with cuDoubleComplex.
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc));
}
#endif
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
// complex<float> GEMM via cublasCgemm: C = alpha * op(A) * op(B) + beta * C.
// Compiled for CUDA, and for ROCm only when HIP_VERSION >= 210.
// NOTE(review): the reinterpret_casts rely on c10::complex<float> being
// layout-compatible with cuComplex.
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc));
}
#endif
// Half-precision GEMM: fp16 A/B/C, with alpha/beta converted to float.
// - ROCm: rocblas_gemm_ex with an fp32 compute type; c is passed as both the
//   input C and the output D operand.
// - CUDA, compute capability >= 5: cublasGemmEx with an fp32 compute type.
//   On CUDA < 11 the handle is switched to CUBLAS_TENSOR_OP_MATH around the
//   call and reset to CUBLAS_DEFAULT_MATH afterwards.
// - CUDA, older devices: cublasSgemmEx on the fp16 buffers.
// NOTE(review): on CUDA < 11, if cublasGemmEx fails and TORCH_CUDABLAS_CHECK
// raises, the math-mode reset is skipped — confirm this is acceptable.
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
// Scalars are passed as float to match the fp32 compute type.
float falpha = alpha;
float fbeta = beta;
// Normalize degenerate leading dimensions, then validate all sizes.
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
#ifdef __HIP_PLATFORM_HCC__
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
rocblas_datatype_f16_r,
lda,
b,
rocblas_datatype_f16_r,
ldb,
&fbeta,
c,
rocblas_datatype_f16_r,
ldc,
c,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5) {
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
CUDA_R_16F,
lda,
b,
CUDA_R_16F,
ldb,
&fbeta,
c,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// Restore the default math mode that was replaced above for the FP16 call;
// only needed on CUDA versions prior to 11.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
} else {
// Pre-Maxwell devices: cublasSgemmEx with fp16 storage for A/B/C.
TORCH_CUDABLAS_CHECK(cublasSgemmEx(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
CUDA_R_16F,
lda,
b,
CUDA_R_16F,
ldb,
&fbeta,
c,
CUDA_R_16F,
ldc));
}
#endif
}
#ifdef __HIP_PLATFORM_HCC__
// bfloat16 GEMM on ROCm via rocblas_gemm_ex: bf16 A/B/C, with alpha/beta
// converted to float and an fp32 compute type (rocblas_datatype_f32_r).
// c is passed as both the input C and the output D operand.
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  // Called first, as in every other gemm/bgemm specialization in this file,
  // so this path reports to the same nondeterminism check as its siblings.
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Scalars are passed as float to match the fp32 compute type.
  float falpha = alpha;
  float fbeta = beta;
  // Normalize degenerate leading dimensions, then validate all sizes.
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      rocblas_datatype_bf16_r,
      lda,
      b,
      rocblas_datatype_bf16_r,
      ldb,
      &fbeta,
      c,
      rocblas_datatype_bf16_r,
      ldc,
      c,
      rocblas_datatype_bf16_r,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      0));
}
#endif
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// bfloat16 GEMM on CUDA 11+ via cublasGemmEx: bf16 A/B/C, with alpha/beta
// converted to float and an fp32 compute type. Requires compute capability
// >= 8 (Ampere); older devices raise.
//
// The handle's math mode is deliberately left untouched.
// - On CUDA 11 tensor cores no longer need the manual CUBLAS_TENSOR_OP_MATH
//   opt-in; the CUDA < 11 guards in gemm<at::Half> exist only for older
//   toolkits.
// - Forcing CUBLAS_DEFAULT_MATH afterwards would discard whatever mode the
//   handle carried before this call.
// - Such a reset would be skipped anyway if cublasGemmEx failed.
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Scalars are passed as float to match the fp32 compute type.
  float falpha = alpha;
  float fbeta = beta;
  // Normalize degenerate leading dimensions, then validate all sizes.
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(prop->major >= 8, "BFloat16 gemm in CUDA requires Ampere or later GPU");
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      CUDA_R_16BF,
      lda,
      b,
      CUDA_R_16BF,
      ldb,
      &fbeta,
      c,
      CUDA_R_16BF,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DFALT_TENSOR_OP));
}
#endif
/* LEVEL 2 BLAS FUNCTIONS */
// Validates the size arguments of a gemv<Dtype> call: m, n must lie in
// [0, INT_MAX] and lda, incx, incy in [1, INT_MAX]. Zero and negative vector
// increments are therefore rejected. Expects locals named m, n, lda, incx,
// incy in scope. Dtype only labels the error message ("gemv<Dtype>").
#define GEMV_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, n); \
CUDABLAS_POSINT_CHECK(gemv<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incx); \
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
} while (0)
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
// complex<double> matrix-vector product via cublasZgemv:
// y = alpha * op(A) * x + beta * y, with A an m x n column-major matrix.
// Compiled for CUDA, and for ROCm only when HIP_VERSION >= 210.
// NOTE(review): relies on c10::complex<double> / cuDoubleComplex layout
// compatibility for the reinterpret_casts.
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
// Normalize a degenerate lda, then validate all sizes.
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(
cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(y), incy));
}
#endif
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
// complex<float> matrix-vector product via cublasCgemv:
// y = alpha * op(A) * x + beta * y, with A an m x n column-major matrix.
// Compiled for CUDA, and for ROCm only when HIP_VERSION >= 210.
// NOTE(review): relies on c10::complex<float> / cuComplex layout
// compatibility for the reinterpret_casts.
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
// loss still happens on TF32. So we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
// Normalize a degenerate lda, then validate all sizes.
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(
cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(y), incy));
}
#endif
// Double-precision matrix-vector product via cublasDgemv:
// y = alpha * op(A) * x + beta * y, with A an m x n column-major matrix.
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
// Normalize a degenerate lda, then validate all sizes.
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(
cublasDgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}
// Single-precision matrix-vector product via cublasSgemv:
// y = alpha * op(A) * x + beta * y, with A an m x n column-major matrix.
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
// loss still happens on TF32. So we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
// Normalize a degenerate lda, then validate all sizes.
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(
cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}
// Half-precision matrix-vector product, y = alpha * op(a) * x + beta * y,
// emulated with a one-row gemm<at::Half> because cuBLAS has no fp16 gemv.
//
// Calling convention (shared with the cublasS/Dgemv-based gemv<float> and
// gemv<double>, so callers never special-case Half): cuBLAS sees matrices as
// column-major. A caller holding a row-major "a" of shape (output, summed)
// reinterprets it as column-major (summed, output), passes m = summed and
// n = output, and requests "trans".
//
// Below, after the optional swap, `m` is the output length and `n` the summed
// length, however "a" was laid out for the caller. Both vectors are viewed as
// column-major matrices with a single row: x as (1, summed) with leading
// dimension incx, y as (1, output) with leading dimension incy. That lets
// strided vectors through. The gemm then forms x * transpose(op(a)), a
// (1, output) row that lines up with y. This is why the op handed to gemm for
// "a" is the opposite of the one requested.
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
  const cublasOperation_t requested_op = _cublasOpFromChar(trans);
  const bool a_is_transposed = requested_op != CUBLAS_OP_N;
  if (a_is_transposed) {
    std::swap(m, n);
  }
  const char op_on_a = a_is_transposed ? 'n' : 't';
  gemm<at::Half>('n', op_on_a, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// bfloat16 matrix-vector product, y = alpha * op(a) * x + beta * y, emulated
// with a one-row gemm<at::BFloat16>. It uses the same calling convention and
// the same trick as gemv<at::Half>:
// - after the optional swap, `m` is the output length and `n` the summed
//   length;
// - x and y are treated as single-row column-major matrices with leading
//   dimensions incx and incy;
// - the gemm computes x * transpose(op(a)), so the op applied to "a" is the
//   opposite of the one requested.
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
  const cublasOperation_t requested_op = _cublasOpFromChar(trans);
  const bool a_is_transposed = requested_op != CUBLAS_OP_N;
  if (a_is_transposed) {
    std::swap(m, n);
  }
  const char op_on_a = a_is_transposed ? 'n' : 't';
  gemm<at::BFloat16>('n', op_on_a, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
#endif
/* LEVEL 1 BLAS FUNCTIONS */
// Double-precision dot product via cublasDdot: *result = sum_i x_i * y_i over
// n elements with strides incx/incy. `handle` is not fetched here; it is
// presumably a parameter supplied by CUDABLAS_DOT_ARGTYPES.
template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double)) {
TORCH_CUDABLAS_CHECK(cublasDdot(handle, n, x, incx, y, incy, result));
}
// Single-precision dot product via cublasSdot: *result = sum_i x_i * y_i over
// n elements with strides incx/incy.
template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasSdot(handle, n, x, incx, y, incy, result));
}
// Unconjugated complex<double> dot product via cublasZdotu:
// *result = sum_i x_i * y_i. See vdot for the conjugated variant.
template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
reinterpret_cast<cuDoubleComplex*>(result)));
}
// Unconjugated complex<float> dot product via cublasCdotu:
// *result = sum_i x_i * y_i. See vdot for the conjugated variant.
template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x),
incx, reinterpret_cast<const cuComplex*>(y), incy,
reinterpret_cast<cuComplex*>(result)));
}
// Half-precision dot product.
// - CUDA_VERSION >= 8000: cublasDotEx with fp16 x/y/result and an fp32
//   execution type.
// - Otherwise, if HIP_VERSION >= 210: rocblas_hdot.
// - Otherwise: raises.
// NOTE(review): on HIP builds CUDA_VERSION is presumably undefined (it then
// evaluates to 0 in #if), which is what selects the rocBLAS branch.
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
#if CUDA_VERSION >= 8000
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
x,
CUDA_R_16F,
incx,
y,
CUDA_R_16F,
incy,
result,
CUDA_R_16F,
CUDA_R_32F));
#elif HIP_VERSION >= 210
TORCH_CUDABLAS_CHECK(rocblas_hdot(
handle,
n,
reinterpret_cast<const rocblas_half*>(x),
incx,
reinterpret_cast<const rocblas_half*>(y),
incy,
reinterpret_cast<rocblas_half*>(result)));
#else
AT_ERROR("Cublas_Hdot requires CUDA 8.0+");
#endif
}
// Conjugated complex<float> dot product via cublasCdotc:
// *result = sum_i conj(x_i) * y_i.
template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
TORCH_CUDABLAS_CHECK(cublasCdotc(handle, n, reinterpret_cast<const cuComplex*>(x),
incx, reinterpret_cast<const cuComplex*>(y), incy,
reinterpret_cast<cuComplex*>(result)));
}
// Conjugated complex<double> dot product via cublasZdotc:
// *result = sum_i conj(x_i) * y_i.
template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
TORCH_CUDABLAS_CHECK(cublasZdotc(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
reinterpret_cast<cuDoubleComplex*>(result)));
}
// This guard compiles getrfBatched and getriBatched only on CUDA; they are
// unavailable on other platforms.
#ifdef CUDART_VERSION
// Batched LU factorization (cublasDgetrfBatched) of batchsize n x n matrices,
// written in place into dA_array, with pivots in ipiv_array and per-matrix
// status in info_array.
template <>
void getrfBatched<double>(
int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasDgetrfBatched(
handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}
// Single-precision counterpart of getrfBatched<double> (cublasSgetrfBatched).
template <>
void getrfBatched<float>(
int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasSgetrfBatched(
handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}
// Batched matrix inverse from LU factors (cublasDgetriBatched): reads the
// factors in dA_array plus pivots in ipiv_array, writes inverses to dC_array
// and per-matrix status to info_array.
// NOTE(review): the output leading dimension is passed as n, so each dC_array
// matrix is assumed to be densely packed n x n — confirm with callers.
template <>
void getriBatched<double>(
int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, double** dC_array) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasDgetriBatched(
handle, n, dA_array, ldda, ipiv_array, dC_array, n, info_array, batchsize));
}
// Single-precision counterpart of getriBatched<double> (cublasSgetriBatched).
// NOTE(review): as there, the output leading dimension is passed as n.
template <>
void getriBatched<float>(
int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, float** dC_array) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasSgetriBatched(
handle, n, dA_array, ldda, ipiv_array, dC_array, n, info_array, batchsize));
}
#endif // CUDART_VERSION
} // namespace blas
} // namespace cuda
} // namespace at