Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions bench/BenchUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <algorithm>
#include <random>
#include <type_traits>
#include <string.h>

#ifdef _OPENMP
#include <omp.h>
Expand Down Expand Up @@ -89,4 +90,37 @@ int fbgemm_get_thread_num() {
#endif
}

/**
 * Look up an integer-valued command line option.
 *
 * Scans argv[1] .. argv[argc - 1] and stops at the first argument that
 * contains @p arg as a substring (argv[0], the program name, is skipped).
 *
 * @param argc           number of entries in @p argv
 * @param argv           command line arguments
 * @param arg            option name to search for, e.g. "--inst="
 * @param non_exist_val  returned when no argument contains @p arg
 * @param def_val        returned when the option is present but carries no
 *                       usable integer: either @p arg does not end in '=',
 *                       or the text following it is not a decimal integer
 * @return the integer following @p arg, or one of the fallbacks above
 */
int parseArgumentInt(
    int argc,
    const char* argv[],
    const char* arg,
    int non_exist_val,
    int def_val) {
  const size_t arg_len = strlen(arg);
  // An option takes a value only when its name ends in '='. Checking the
  // name itself (instead of the matched text) gives the same answer and
  // avoids reading before the buffer when arg is empty.
  const bool takes_value = arg_len > 0 && arg[arg_len - 1] == '=';

  for (int i = 1; i < argc; ++i) {
    const char* ptr = strstr(argv[i], arg);
    if (ptr == nullptr) {
      continue;
    }
    int res = def_val;
    // sscanf leaves res untouched when nothing could be converted, so its
    // return value must be checked before trusting res.
    if (takes_value && sscanf(ptr + arg_len, "%d", &res) == 1) {
      return res;
    }
    return def_val;
  }
  return non_exist_val;
}

/**
 * Test whether a flag is present on the command line.
 *
 * @param argc     number of entries in @p argv
 * @param argv     command line arguments; argv[0] is not examined
 * @param arg      flag text to search for, e.g. "--no-flush"
 * @param def_val  value returned when the flag is not found
 * @return true as soon as some argument contains @p arg as a substring,
 *         otherwise @p def_val
 */
bool parseArgumentBool(
    int argc,
    const char* argv[],
    const char* arg,
    bool def_val) {
  int idx = 1; // skip the program name
  while (idx < argc) {
    if (strstr(argv[idx], arg) != nullptr) {
      return true;
    }
    ++idx;
  }
  return def_val;
}
} // namespace fbgemm
13 changes: 13 additions & 0 deletions bench/BenchUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ void cache_evict(const T& vec) {
}
}

/**
* Parse application command line arguments.
*
* parseArgumentInt searches argv[1] .. argv[argc - 1] for the first argument
* that contains the string arg (argv[0] is skipped).
* - If no argument contains arg, non_exist_val is returned.
* - If an argument contains arg and arg ends in '=', the integer that
*   follows arg in that argument is returned (e.g. "--inst=4" yields 4 for
*   arg "--inst=").
* - If an argument contains arg but arg does not end in '=', def_val is
*   returned.
*
* parseArgumentBool returns true if any of argv[1] .. argv[argc - 1]
* contains the string arg, and def_val otherwise.
*/
int parseArgumentInt(
int argc,
const char* argv[],
const char* arg,
int non_exist_val,
int def_val);
bool parseArgumentBool(
int argc, const char* argv[], const char* arg, bool def_val);

/**
* @param Fn functor to execute
* @param Fe data eviction functor
Expand Down
186 changes: 91 additions & 95 deletions bench/FP16Benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ void test_xerbla(char* srname, const int* info, int){
printf("\nXERBLA(MKL Error) is called :%s: %d\n", srname, *info);
}

void performance_test(int num_instances, bool flush) {
void performance_test(
int num_instances, bool flush, int repetitions, bool is_mkl) {

#if defined(USE_MKL)
mkl_set_xerbla((XerblaEntry)test_xerbla);
Expand Down Expand Up @@ -217,77 +218,81 @@ void performance_test(int num_instances, bool flush) {
}

#if defined(USE_MKL)
// Gold via MKL sgemm
type = "MKL_FP32";
if (is_mkl) {
// Gold via MKL sgemm
type = "MKL_FP32";
#elif defined(USE_BLAS)
type = "BLAS_FP32";
type = "BLAS_FP32";
#else
type = "REF_FP32";
type = "REF_FP32";
#endif

ttot = measureWithWarmup(
[&]() {
int copy = num_instances == 1 ? 0 : fbgemm_get_thread_num();
ttot = measureWithWarmup(
[&]() {
int copy = num_instances == 1 ? 0 : fbgemm_get_thread_num();
for(int i = 0; i < repetitions; ++i) {
#if defined(USE_MKL) || defined(USE_BLAS)
cblas_sgemm(
CblasRowMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
1.0,
A[copy].data(),
k,
Bt[copy].data(),
btran == matrix_op_t::NoTranspose ? kAligned : nAligned,
beta,
C_ref[copy].data(),
n);
cblas_sgemm(
CblasRowMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
1.0,
A[copy].data(),
k,
Bt[copy].data(),
btran == matrix_op_t::NoTranspose ? kAligned : nAligned,
beta,
C_ref[copy].data(),
n);
#else
cblas_sgemm_ref(
matrix_op_t::NoTranspose,
btran,
m,
n,
k,
alpha,
A[copy].data(),
k,
B[copy].data(),
(btran == matrix_op_t::NoTranspose) ? n : k,
beta,
C_ref[copy].data(),
n);
cblas_sgemm_ref(
matrix_op_t::NoTranspose,
btran,
m,
n,
k,
alpha,
A[copy].data(),
k,
B[copy].data(),
(btran == matrix_op_t::NoTranspose) ? n : k,
beta,
C_ref[copy].data(),
n);
#endif
},
3,
NITER,
[&]() {
if (flush) {
int copy = num_instances == 1 ? 0 : fbgemm_get_thread_num();
cache_evict(A[copy]);
}
},
3,
NITER,
[&]() {
if (flush) {
int copy = num_instances == 1 ? 0 : fbgemm_get_thread_num();
cache_evict(A[copy]);
#if defined(USE_MKL) || defined(USE_BLAS)
cache_evict(Bt[copy]);
cache_evict(Bt[copy]);
#else
cache_evict(B[copy]);
cache_evict(B[copy]);
#endif
cache_evict(C_ref[copy]);
}
},
// Use OpenMP if num instances > 1
num_instances > 1);

gflops = nflops / ttot / 1e9;
gbs = nbytes / ttot / 1e9;
printf(
"\n%30s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
type.c_str(),
m,
n,
k,
gflops,
gbs);
cache_evict(C_ref[copy]);
}
},
// Use OpenMP if num instances > 1
num_instances > 1);

gflops = nflops / ttot / 1e9;
gbs = nbytes / ttot / 1e9;
printf(
"\n%30s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
type.c_str(),
m,
n,
k,
gflops * repetitions,
gbs * repetitions);
}

type = "FBP_" + std::string(typeid(btype).name());

Expand All @@ -306,15 +311,17 @@ void performance_test(int num_instances, bool flush) {
int num_threads = num_instances == 1 ? fbgemm_get_num_threads() : 1;
int tid = num_instances == 1 ? fbgemm_get_thread_num() : 0;

cblas_gemm_compute(
matrix_op_t::NoTranspose,
m,
A[copy].data(),
*Bp[copy],
beta,
C_fb[copy].data(),
tid,
num_threads);
for(int i = 0; i < repetitions; ++i) {
cblas_gemm_compute(
matrix_op_t::NoTranspose,
m,
A[copy].data(),
*Bp[copy],
beta,
C_fb[copy].data(),
tid,
num_threads);
}
},
3,
NITER,
Expand All @@ -336,27 +343,20 @@ void performance_test(int num_instances, bool flush) {
m,
n,
k,
gflops,
gbs);
gflops * repetitions,
gbs * repetitions);
}
}

int main(int argc, char** argv) {
int main(int argc, const char* argv[]) {
int num_instances = 1;
#ifdef _OPENMP
const char* inst = getenv("GEMMBENCH_NUM_INSTANCES");
int num_instances = 1;
if (inst != nullptr && *inst) {
num_instances = std::max(atoi(inst), num_instances);
}

for (auto i = 1; i < argc; ++i) {
static const char param[] = "--inst=";
const char* ptr = strstr(argv[i], param);
if (ptr) {
ptr += sizeof(param) - 1; // null terminated
num_instances = std::max(atoi(ptr), num_instances);
}
}
num_instances = parseArgumentInt(
argc, argv, "--inst=", num_instances, num_instances);
printf("Running %d instances\n", num_instances);
if (num_instances > 1) {
// Set-up execution for multi-instance mode
Expand All @@ -372,23 +372,19 @@ int main(int argc, char** argv) {
} else {
// When running single instance use OMP_NUM_THREADS to determine
// parallelism. Default behaviour is using a single thread.
// Use 1 thread unless OMP_NUM_THREADS is explicitly set.
int num_threads = parseArgumentInt(
argc, argv, "--num_threads=", 1, 1);
const char* val = getenv("OMP_NUM_THREADS");
if (val == nullptr || !*val) {
omp_set_num_threads(1);
omp_set_num_threads(num_threads);
}
}

#endif

bool flush = true;
for (auto i = 1; i < argc; ++i) {
static const char param[] = "--no-flush";
const char* ptr = strstr(argv[i], param);
if (ptr) {
flush = false;
}
}
int repetitions = parseArgumentInt(argc, argv, "--repit=", 1, 1);
bool no_flush = parseArgumentBool(argc, argv, "--no-flush", false);
bool no_mkl = parseArgumentBool(argc, argv, "--no-mkl", false);

performance_test(num_instances, flush);
performance_test(num_instances, !no_flush, repetitions, !no_mkl);
}