hipblasGemmEx does not match the CPU or ROCBlas results for int8 x int8 to int32 matrix multiplication #498

Closed
xinyazhang opened this issue Jun 20, 2022 · 3 comments

@xinyazhang

A minimal test case is attached as igemm_all_in_one.cc.gz. After decompressing it, build it with:
g++ -I/usr/include/eigen3 -I/opt/rocm/include igemm_all_in_one.cc -Wl,-rpath,/opt/rocm/lib/ -L/opt/rocm/lib/ -lrocblas -lhipblas -lamdhip64 -o igemm_aiw

What is the expected behavior

  • "Result from hipblasGemmEx with i8 input" matches the output of "Result from Eigen with i32 data type" and "Result from rocblas_gemm_ex with i8 input"

What actually happens

  • "Result from hipblasGemmEx with i8 input" shows
    • -15 65 110 -65
    • -49 -79 -90 117
  • The reference outputs "Result from Eigen with i32 data type" and "Result from rocblas_gemm_ex with i8 input" both show
    • -55 16 89 -44
    • 122 -102 68 -39
  • Note that the reference output matches the expected values in the onnxruntime unit tests.

How to reproduce

  • Run the unit tests, or build and run the attached minimal test case as described above

Environment

Hardware description
GPU: gfx90a
CPU: AMD EPYC 7542 32-Core Processor
Software version
ROCK: modinfo amdgpu | grep version shows version 5.13.20.22.10
ROCR: v5.1.3
HCC/HIP version: 5.1.20532-f592a741
Library: v5.1.3
@amcamd
Contributor

amcamd commented Aug 3, 2022

@xinyazhang Thank you for reporting this error. Fixes are in the following hipBLAS and rocBLAS commits:

  • hipBLAS: 5b1d8ff
  • rocBLAS: 335704afda003156408280bfc4d10bd2c59ff5f4

The line below is required before calling hipblasGemmEx:

hipblasSetInt8Datatype(handle, HIPBLAS_INT8_DATATYPE_INT8);

The code below includes this call and verifies the result. Running it gives:

Result from rocblas_gemm_ex with i8 input
62, 74, 14, 63, 20, -85, 43, 93,
Result from hipblasGemmEx with i8 input
62, 74, 14, 63, 20, -85, 43, 93,

Without the call to hipblasSetInt8Datatype, the hipBLAS result is incorrect:

Result from rocblas_gemm_ex with i8 input
62, 74, 14, 63, 20, -85, 43, 93,
Result from hipblasGemmEx with i8 input
-49, 147, -89, 34, 117, 20, -47, -2,

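#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <iostream>
#include <vector>
// The HIP headers need the AMD platform macro when building with plain g++
#ifndef __HIP_PLATFORM_AMD__
#define __HIP_PLATFORM_AMD__
#endif
#include <hip/hip_runtime_api.h>
#include <hipblas.h>
#include <rocblas.h>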

// Evaluate the HIP call exactly once; report and abort on failure.
#define CHECK_HIP_ERROR(expr)                             \
    do                                                    \
    {                                                     \
        hipError_t hip_err_ = (expr);                     \
        if(hip_err_ != hipSuccess)                        \
        {                                                 \
            fprintf(stderr,                               \
                    "HIP error: '%s'(%d) at %s:%d\n",     \
                    hipGetErrorString(hip_err_),          \
                    (int)hip_err_,                        \
                    __FILE__,                             \
                    __LINE__);                            \
            exit(EXIT_FAILURE);                           \
        }                                                 \
    } while(0)

// Map a hipBLAS status to its enumerator name for error messages.
static const char* hipblas_status_name(hipblasStatus_t status)
{
    switch(status)
    {
    case HIPBLAS_STATUS_NOT_INITIALIZED:  return "HIPBLAS_STATUS_NOT_INITIALIZED";
    case HIPBLAS_STATUS_ALLOC_FAILED:     return "HIPBLAS_STATUS_ALLOC_FAILED";
    case HIPBLAS_STATUS_INVALID_VALUE:    return "HIPBLAS_STATUS_INVALID_VALUE";
    case HIPBLAS_STATUS_MAPPING_ERROR:    return "HIPBLAS_STATUS_MAPPING_ERROR";
    case HIPBLAS_STATUS_EXECUTION_FAILED: return "HIPBLAS_STATUS_EXECUTION_FAILED";
    case HIPBLAS_STATUS_INTERNAL_ERROR:   return "HIPBLAS_STATUS_INTERNAL_ERROR";
    case HIPBLAS_STATUS_NOT_SUPPORTED:    return "HIPBLAS_STATUS_NOT_SUPPORTED";
    case HIPBLAS_STATUS_INVALID_ENUM:     return "HIPBLAS_STATUS_INVALID_ENUM";
    case HIPBLAS_STATUS_UNKNOWN:          return "HIPBLAS_STATUS_UNKNOWN";
    default:                              return "unrecognized hipBLAS status";
    }
}

// Evaluate the hipBLAS call exactly once; report and abort on failure.
#define CHECK_HIPBLAS_ERROR(expr)                         \
    do                                                    \
    {                                                     \
        hipblasStatus_t blas_status_ = (expr);            \
        if(blas_status_ != HIPBLAS_STATUS_SUCCESS)        \
        {                                                 \
            fprintf(stderr,                               \
                    "hipBLAS error: %s at %s:%d\n",       \
                    hipblas_status_name(blas_status_),    \
                    __FILE__,                             \
                    __LINE__);                            \
            exit(EXIT_FAILURE);                           \
        }                                                 \
    } while(0)

int main()
{
    int m = 2;
    int n = 4;
    int k = 4;
    int sizeA = m*k;
    int sizeB = k*n;
    int sizeC = m*n;

    std::vector<int8_t>  hA{-3, 7, 5, -6, 4, -5, 8, 7};
    std::vector<int8_t>  hB{5, -3,  7,  8, -6, -8, -3,  6, 7,  9,  9, -5, 8,  7, -6,  7};
    std::vector<int32_t> hC(sizeC);

    int32_t alpha = 1;
    int32_t beta = 0;
    int lda = m;
    int ldb = k;
    int ldc = m;

    int8_t  *da, *db;
    int32_t *dc; // C (and D) hold the int32 result
    CHECK_HIP_ERROR(hipMalloc(&da, sizeA * sizeof(int8_t)));
    CHECK_HIP_ERROR(hipMalloc(&db, sizeB * sizeof(int8_t)));
    CHECK_HIP_ERROR(hipMalloc(&dc, sizeC * sizeof(int32_t)));

    // Reference result from rocBLAS (inputs are copied to the device inside the block)
    {
        rocblas_handle handle;
        rocblas_create_handle(&handle);
        rocblas_operation transa = rocblas_operation_none;
        rocblas_operation transb = rocblas_operation_none;

        CHECK_HIP_ERROR(hipMemcpy(da, hA.data(), sizeof(int8_t) * sizeA, hipMemcpyHostToDevice));
        CHECK_HIP_ERROR(hipMemcpy(db, hB.data(), sizeof(int8_t) * sizeB, hipMemcpyHostToDevice));
        CHECK_HIP_ERROR(hipMemset(dc, 0, sizeof(int32_t) * sizeC ));

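        rocblas_status rb_status = rocblas_gemm_ex(handle,
                                                   transa, transb,
                                                   m, n, k,
                                                   &alpha,
                                                   da, rocblas_datatype_i8_r, lda,
                                                   db, rocblas_datatype_i8_r, ldb,
                                                   &beta,
                                                   dc, rocblas_datatype_i32_r, ldc,
                                                   dc, rocblas_datatype_i32_r, ldc, // C == D
                                                   rocblas_datatype_i32_r,
                                                   rocblas_gemm_algo_standard,
                                                   0,  // solution_index
                                                   0); // flags (none set)
        if(rb_status != rocblas_status_success)
        {
            fprintf(stderr, "rocBLAS error: status %d\n", (int)rb_status);
            exit(EXIT_FAILURE);
        }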
        CHECK_HIP_ERROR(hipMemcpy(hC.data(), dc, sizeof(int32_t) * sizeC, hipMemcpyDeviceToHost));
        std::cout << "Result from rocblas_gemm_ex with i8 input" << std::endl;
        for(int32_t i : hC) { std::cout << i << ", "; }
        std::cout << std::endl;

        rocblas_destroy_handle(handle);
    }

    // hipBLAS
    hipblasHandle_t handle;
    CHECK_HIPBLAS_ERROR(hipblasCreate(&handle));

    // Required: treat HIPBLAS_R_8I as plain int8_t instead of the packed_int8x4 layout
    CHECK_HIPBLAS_ERROR(hipblasSetInt8Datatype(handle, HIPBLAS_INT8_DATATYPE_INT8));

    hipblasOperation_t transa = HIPBLAS_OP_N;
    hipblasOperation_t transb = HIPBLAS_OP_N;

    CHECK_HIP_ERROR(hipMemcpy(da, hA.data(), sizeof(int8_t) * sizeA, hipMemcpyHostToDevice));
    CHECK_HIP_ERROR(hipMemcpy(db, hB.data(), sizeof(int8_t) * sizeB, hipMemcpyHostToDevice));
    CHECK_HIP_ERROR(hipMemset(dc, 0, sizeof(int32_t) * sizeC));
    CHECK_HIPBLAS_ERROR(hipblasGemmEx(handle,
                                      transa,
                                      transb,
                                      m, n, k,
                                      &alpha,
                                      da, HIPBLAS_R_8I, lda,
                                      db, HIPBLAS_R_8I, ldb,
                                      &beta,
                                      dc, HIPBLAS_R_32I, ldc,
                                      HIPBLAS_R_32I,
                                      HIPBLAS_GEMM_DEFAULT));
    CHECK_HIP_ERROR(hipMemcpy(hC.data(), dc, sizeof(int32_t) * sizeC, hipMemcpyDeviceToHost));
    std::cout << "Result from hipblasGemmEx with i8 input" << std::endl;
    for(int32_t i : hC) { std::cout << i << ", "; }
    std::cout << std::endl;

    CHECK_HIP_ERROR(hipFree(da));
    CHECK_HIP_ERROR(hipFree(db));
    CHECK_HIP_ERROR(hipFree(dc));
    CHECK_HIPBLAS_ERROR(hipblasDestroy(handle));
    return 0;
}

amcamd closed this as completed Oct 11, 2022
@jinz2014

warning: 'hipblasSetInt8Datatype' is deprecated: "The hipblasSetInt8Datatype function will be removed in a future release and only int8_t datatype will be supported. packed_int8x4 datatype support will be removed." [-Wdeprecated-declarations]
hipblasSetInt8Datatype(handle, HIPBLAS_INT8_DATATYPE_INT8);

Despite the deprecation warning, hipblasSetInt8Datatype is still required.

@jinz2014

Is rocBLAS preferred over hipBLAS?
