// labels.cu
#include "labels.h"
namespace kmeans {
namespace detail {
// Owns one cuBLAS handle: the constructor creates it with cublasCreate and
// the destructor releases it with cublasDestroy.
struct cublas_state {
    cublasHandle_t cublas_handle;

    // Creates the handle. If cublasCreate does not return
    // CUBLAS_STATUS_SUCCESS, prints a message and terminates the process
    // with exit status 1.
    cublas_state() {
        cublasStatus_t stat = cublasCreate(&cublas_handle);
        if (stat != CUBLAS_STATUS_SUCCESS) {
            std::cout << "CUBLAS initialization failed" << std::endl;
            exit(1);
        }
    }

    // Releases the handle. A failure is reported but does not terminate the
    // process: the instance of this type in this file has static storage
    // duration, so this destructor runs during static destruction, and
    // calling exit() at that point is undefined behaviour.
    ~cublas_state() {
        cublasStatus_t stat = cublasDestroy(cublas_handle);
        if (stat != CUBLAS_STATUS_SUCCESS) {
            std::cout << "CUBLAS destruction failed" << std::endl;
        }
    }

private:
    // Not copyable: a copy would hold the same handle and cublasDestroy
    // would then be called twice on it. Declared but intentionally left
    // undefined so any attempted copy fails to compile or link.
    cublas_state(const cublas_state &);
    cublas_state &operator=(const cublas_state &);
};
// Handle holder used by the gemm overloads in this file. It is a
// namespace-scope object, so its constructor (cublasCreate) runs during
// static initialization and its destructor (cublasDestroy) runs during
// static destruction at program exit.
cublas_state state;
// Single-precision overload: forwards all arguments unchanged to cublasSgemm
// using the handle held by `state`. If cuBLAS returns anything other than
// CUBLAS_STATUS_SUCCESS, prints "Invalid Sgemm" to stdout and terminates the
// process with exit status 1.
void gemm(cublasOperation_t transa, cublasOperation_t transb,
          int m, int n, int k, const float *alpha,
          const float *A, int lda, const float *B, int ldb,
          const float *beta,
          float *C, int ldc) {
    const cublasStatus_t rc =
        cublasSgemm(state.cublas_handle, transa, transb, m, n, k,
                    alpha, A, lda, B, ldb, beta, C, ldc);

    // Success path: nothing further to do.
    if (rc == CUBLAS_STATUS_SUCCESS) {
        return;
    }

    std::cout << "Invalid Sgemm" << std::endl;
    exit(1);
}
// Double-precision overload: forwards all arguments unchanged to cublasDgemm
// using the handle held by `state`. If cuBLAS returns anything other than
// CUBLAS_STATUS_SUCCESS, prints "Invalid Dgemm" to stdout and terminates the
// process with exit status 1.
void gemm(cublasOperation_t transa, cublasOperation_t transb,
          int m, int n, int k, const double *alpha,
          const double *A, int lda, const double *B, int ldb,
          const double *beta,
          double *C, int ldc) {
    const cublasStatus_t rc =
        cublasDgemm(state.cublas_handle, transa, transb, m, n, k,
                    alpha, A, lda, B, ldb, beta, C, ldc);

    // Success path: nothing further to do.
    if (rc == CUBLAS_STATUS_SUCCESS) {
        return;
    }

    std::cout << "Invalid Dgemm" << std::endl;
    exit(1);
}
}
}