In [1]:
!pip install nvcc4jupyter

Collecting nvcc4jupyter
  Downloading nvcc4jupyter-1.2.1-py3-none-any.whl.metadata (5.1 kB)
Downloading nvcc4jupyter-1.2.1-py3-none-any.whl (10 kB)
Installing collected packages: nvcc4jupyter
Successfully installed nvcc4jupyter-1.2.1


In [2]:
%load_ext nvcc4jupyter

Detected platform "Colab". Running its setup...
Source files will be saved in "/tmp/tmp42ijolwk".


In [3]:
%%writefile cuda_kernel_bf16.cu
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cmath>
#include <iostream>
#include <vector>
#include <iomanip>
#include <cstdlib>

#define WARP_SIZE 32

// math functions according to the problem statement
__device__ __forceinline__ float fast_sin(float x) { return __sinf(x); }
__device__ __forceinline__ float fast_cos(float x) { return __cosf(x); }
__device__ __forceinline__ float fast_log(float x) { return __logf(x); }
__device__ __forceinline__ float fast_exp(float x) { return __expf(x); }

// using __nv_bloat16 vectors
struct __align__(8) bf16x4 {
    __nv_bfloat16 x, y, z, w;
};

__device__ __forceinline__ float bf2f(__nv_bfloat16 v) { return __bfloat162float(v); }
__device__ __forceinline__ __nv_bfloat16 f2bf(float v) { return __float2bfloat16(v); }

__global__ void transform_kernel_bf16(const bf16x4* __restrict__ input, bf16x4* __restrict__ output, float* __restrict__ global_sum,int n) {
    // shared memory
    extern __shared__ float shared_data[];

    const int tid = threadIdx.x;
    const int lane = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;
    const int block_size = blockDim.x;
    const int block_start = blockIdx.x * block_size * 4;
    const int warp_start  = block_start + warp_id * 32;

    // staging input to shared memory
    for (int i = lane; i < 32; i += WARP_SIZE) {
        const int global_idx = warp_start + i;
        if (global_idx < n) {
            const int vec_idx  = global_idx / 4;
            const int elem_idx = global_idx % 4;
            bf16x4 data = input[vec_idx];
            float fval = 0.f;
            switch (elem_idx) {
                case 0: fval = bf2f(data.x); break;
                case 1: fval = bf2f(data.y); break;
                case 2: fval = bf2f(data.z); break;
                case 3: fval = bf2f(data.w); break;
            }
            shared_data[warp_id * 32 + i] = fval;
        }
    }
    __syncthreads();

    float warp_sin_sum = 0.0f;
    const int local_idx  = warp_id * 32 + lane;
    const int global_idx = warp_start + lane;

    if (global_idx < n) {
        const int pos_in_32 = lane;
        const int mod_4 = pos_in_32 % 4;
        const float current_val = shared_data[local_idx];
        float result = 0.0f;

        // applying transformations
        if (pos_in_32 < 16) {
            switch (mod_4) {
                case 0: result = fast_sin(current_val); break;
                case 1: result = fast_cos(current_val); break;
                case 2: result = (current_val > 0.0f) ? fast_log(current_val) : -INFINITY; break;
                case 3: result = fast_exp(current_val); break;
            }
        } else {
            const int ref_pos = pos_in_32 - 16;
            const float ref_val = shared_data[warp_id * 32 + ref_pos];
            float ref_result = 0.0f;

            switch (ref_pos % 4) {
                case 0: ref_result = fast_sin(ref_val); break;
                case 1: ref_result = fast_cos(ref_val); break;
                case 2: ref_result = (ref_val > 0.0f) ? fast_log(ref_val) : -INFINITY; break;
                case 3: ref_result = fast_exp(ref_val); break;
            }

            switch (mod_4) {
                case 0: result = ref_result * fast_sin(current_val); break;
                case 1: result = ref_result * fast_cos(current_val); break;
                case 2: result = (current_val > 0.0f) ? (ref_result * fast_log(current_val)) : -INFINITY; break;
                case 3: result = ref_result * fast_exp(current_val); break;
            }
        }
        shared_data[local_idx + block_size * 4] = result;

        // global sum calculation of sin terms according to the problem
        if (mod_4 == 0 && lane < 31) {
            const int cos_idx = local_idx + 1;
            float cos_val = shared_data[cos_idx + block_size * 4];
            if (cos_val > 0.5f) {
                warp_sin_sum += result;
            }
        }
    }
    __syncthreads();

    // back to bf16
    for (int i = lane; i < 32; i += WARP_SIZE) {
        const int gidx = warp_start + i;
        if (gidx < n) {
            const int vec_idx  = gidx / 4;
            const int elem_idx = gidx % 4;
            bf16x4* out_vec = &output[vec_idx];
            float result_val = shared_data[warp_id * 32 + i + block_size * 4];
            __nv_bfloat16 b = f2bf(result_val);
            switch (elem_idx) {
                case 0: out_vec->x = b; break;
                case 1: out_vec->y = b; break;
                case 2: out_vec->z = b; break;
                case 3: out_vec->w = b; break;
            }
        }
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        warp_sin_sum += __shfl_down_sync(0xFFFFFFFF, warp_sin_sum, offset);
    }
    if (lane == 0 && warp_sin_sum != 0.0f) {
        atomicAdd(global_sum, warp_sin_sum);
    }
}

void measure_performance_bf16(int n, int num_runs = 4) {
    bf16x4 *d_input, *d_output, *d_temp;
    float *d_global_sum;
    const int n_vec = (n + 3) / 4;
    const size_t vec_bytes = sizeof(bf16x4);
    const size_t data_bytes = n_vec * vec_bytes;
    const double logical_bytes = 2.0 * static_cast<double>(n) * sizeof(__nv_bfloat16);

    std::vector<bf16x4> h_input(n_vec);
    for (int i = 0; i < n_vec; i++) {
        float x = 0.1f + static_cast<float>(rand()) / (static_cast<float>(RAND_MAX/4.9f));
        float y = 0.1f + static_cast<float>(rand()) / (static_cast<float>(RAND_MAX/4.9f));
        float z = 0.1f + static_cast<float>(rand()) / (static_cast<float>(RAND_MAX/4.9f));
        float w = 0.1f + static_cast<float>(rand()) / (static_cast<float>(RAND_MAX/4.9f));
        h_input[i].x = __float2bfloat16(x);
        h_input[i].y = __float2bfloat16(y);
        h_input[i].z = __float2bfloat16(z);
        h_input[i].w = __float2bfloat16(w);
    }

    cudaMalloc(&d_input,  data_bytes);
    cudaMalloc(&d_output, data_bytes);
    cudaMalloc(&d_temp,   data_bytes);
    cudaMalloc(&d_global_sum, sizeof(float));

    cudaMemcpy(d_input, h_input.data(), data_bytes, cudaMemcpyHostToDevice);

    const int block_size = 512;
    const int grid_size  = (n_vec + block_size - 1) / block_size;
    const int shared_mem_size = 2 * block_size * 4 * sizeof(float);

    std::cout << "performance measurement (bf16)\n";
    std::cout << "array size: " << n << " (" << std::fixed << std::setprecision(2)
              << (static_cast<double>(n) * sizeof(__nv_bfloat16) / (1024.0 * 1024.0)) << " MB BF16)\n";
    std::cout << "block size: " << block_size << " threads\n";
    std::cout << "shared memory: " << (shared_mem_size / 1024.0) << " KB per block\n";

    cudaEvent_t start, stop; cudaEventCreate(&start); cudaEventCreate(&stop);

    float total_time_ms = 0.0f; float global_sum_result = 0.0f;

    std::cout << "\nrunning kernel " << num_runs << " time\n";
    for (int run = 0; run < num_runs; ++run) {
        float zero = 0.0f; cudaMemcpy(d_global_sum, &zero, sizeof(float), cudaMemcpyHostToDevice);

        cudaEventRecord(start);
        transform_kernel_bf16<<<grid_size, block_size, shared_mem_size>>>(d_input, d_output, d_global_sum, n);
        cudaEventRecord(stop);
        cudaEventSynchronize(stop);

        cudaError_t err = cudaGetLastError();
        if (err != cudaSuccess) { std::cout << "kernel error: " << cudaGetErrorString(err) << "\n"; continue; }

        float ms = 0.0f; cudaEventElapsedTime(&ms, start, stop); total_time_ms += ms;
        if (run == 0) { float tmp=0.f; cudaMemcpy(&tmp, d_global_sum, sizeof(float), cudaMemcpyDeviceToHost); global_sum_result = tmp; }
    }

    const float avg_ms = total_time_ms / num_runs;
    const double kernel_bw_GBps = (logical_bytes / (avg_ms / 1000.0)) / (1024.0*1024.0*1024.0);

    cudaEventRecord(start);
    cudaMemcpy(d_temp, d_input, data_bytes, cudaMemcpyDeviceToDevice);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float memcpy_ms = 0.0f; cudaEventElapsedTime(&memcpy_ms, start, stop);
    const double memcpy_bw_GBps = (logical_bytes / (memcpy_ms / 1000.0)) / (1024.0*1024.0*1024.0);
    const double efficiency = (kernel_bw_GBps / memcpy_bw_GBps) * 100.0;

    std::cout << "\nperformance results (bf16):\n";
    std::cout << "kernel execution time: " << std::fixed << std::setprecision(3) << avg_ms << " ms\n";
    std::cout << "kernel bandwidth: " << std::fixed << std::setprecision(2) << kernel_bw_GBps << " GB/s\n";
    std::cout << "memory bandwidth: " << std::fixed << std::setprecision(2) << memcpy_bw_GBps << " GB/s\n";
    std::cout << "efficiency: " << std::fixed << std::setprecision(1) << efficiency << "%\n";
    std::cout << "global sin sum: " << std::scientific << std::setprecision(6) << global_sum_result << "\n";

    cudaEventDestroy(start); cudaEventDestroy(stop);
    cudaFree(d_input); cudaFree(d_output); cudaFree(d_temp); cudaFree(d_global_sum);
}

int main() {
    const int n = 100000000;
    const int num_runs = 4;
    srand(42);

    cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0);
    std::cout << "using google colab for gpu: " << prop.name << "\n";
    std::cout << "global memory: " << (prop.totalGlobalMem / (1024.0 * 1024.0 * 1024.0)) << " GB\n";
    std::cout << "shared memory per block: " << (prop.sharedMemPerBlock / 1024.0) << " KB\n\n";

    measure_performance_bf16(n, num_runs);
    return 0;
}

Writing cuda_kernel_bf16.cu


In [4]:
!nvcc -o cuda_kernel_bf16 cuda_kernel_bf16.cu -std=c++17 -arch=sm_80 -O3 -Xptxas -O3

In [5]:
!./cuda_kernel_bf16

using google colab for gpu: NVIDIA A100-SXM4-40GB
global memory: 39.5574 GB
shared memory per block: 48 KB

performance measurement (bf16)
array size: 100000000 (190.73 MB BF16)
block size: 512 threads
shared memory: 16.00 KB per block

running kernel 4 time

performance results (bf16):
kernel execution time: 1.899 ms
kernel bandwidth: 196.17 GB/s
memory bandwidth: 288.04 GB/s
efficiency: 68.1%
global sin sum: 1.008407e+05
