[Perf] Shared Array APIs #5338

Closed · turbo0628 opened this issue Jul 6, 2022 · 7 comments · Fixed by #5583
Labels: feature request (Suggest an idea on this project)

turbo0628 (Member) commented Jul 6, 2022

TL;DR

I would like to request a scratch-pad (shared memory in CUDA) API. With it, the n-body simulation runs roughly 2x faster on CUDA GPUs, which is on par with the best CUDA implementations.

The Problem

The Taichi kernel for the n-body simulation is shown below:

import taichi as ti

nBodies = 3072
softening = 1e-9
dt = 0.01
block_dim=128

ti.init(arch=ti.cuda)

bodies = ti.field(shape=(nBodies * 4), dtype=ti.float32)
velocities = ti.field(shape=(nBodies, 3), dtype=ti.float32)
@ti.kernel
def bodyForce():
    block_dim = ti.static(128)
    ti.loop_config(block_dim=block_dim)
    # LOOP 1
    for i in range(nBodies):
        Fx = 0.0
        Fy = 0.0
        Fz = 0.0
        bx = bodies[i * 4 + 0]
        by = bodies[i * 4 + 1]
        bz = bodies[i * 4 + 2]
        # LOOP 2
        for j in range(nBodies // block_dim):
            # FETCH
            # LOOP BLOCK
            for t in range(block_dim // 8):
                # LOOP UNROLL
                for s in ti.static(range(8)):
                    #############
                    # LOADS FROM DRAM
                    dx = bodies[(j * block_dim + t * 8 + s) * 4 + 0] - bx
                    dy = bodies[(j * block_dim + t * 8 + s) * 4 + 1] - by
                    dz = bodies[(j * block_dim + t * 8 + s) * 4 + 2] - bz
                    #############
                    distSqr = dx * dx + dy * dy + dz * dz + softening
                    invDist = 1.0 / ti.sqrt(distSqr)
                    invDist3 = invDist * invDist * invDist
                    Fx += dx * invDist3
                    Fy += dy * invDist3
                    Fz += dz * invDist3
        velocities[i, 0] += dt * Fx
        velocities[i, 1] += dt * Fy
        velocities[i, 2] += dt * Fz

    for i in range(nBodies):
        bodies[i * 4 + 0] = bodies[i * 4 + 0] + velocities[i, 0] * dt
        bodies[i * 4 + 1] = bodies[i * 4 + 1] + velocities[i, 1] * dt
        bodies[i * 4 + 2] = bodies[i * 4 + 2] + velocities[i, 2] * dt

bodyForce()

Note that the inner nBodies loop is factored into three smaller loops: LOOP 2, LOOP BLOCK, and LOOP UNROLL. This partial loop unrolling helps build a more GPU-friendly instruction flow, but we won't dive into the details here (see the sketch below).
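
To see that this factorization is exact, here is a minimal plain-Python sketch (outside Taichi, reusing the nBodies and block_dim values from above) checking that the three nested loops enumerate the same indices as a single flat loop over all bodies:

nBodies = 3072
block_dim = 128

# Flat loop over all bodies.
flat = list(range(nBodies))

# Factorized into LOOP 2 (tiles), LOOP BLOCK, and an 8-wide LOOP UNROLL.
factorized = [
    j * block_dim + t * 8 + s
    for j in range(nBodies // block_dim)   # LOOP 2
    for t in range(block_dim // 8)         # LOOP BLOCK
    for s in range(8)                      # LOOP UNROLL
]

assert factorized == flat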

Ideally we want to fetch data from DRAM into the scratch pad at FETCH, and replace the LOADS FROM DRAM with local loads from the scratch pad. The target code should be equivalent to the following CUDA kernel:

__global__
void bodyForce(float4 *p, float4 *v, float dt, int n) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  // LOOP 1
  if (i < n) {
    float Fx = 0.0f; float Fy = 0.0f; float Fz = 0.0f;
    // LOOP 2
    for (int tile = 0; tile < gridDim.x; tile++) {
      // FETCH
      /*********************************/
       __shared__ float3 spos[BLOCK_SIZE];
      float4 tpos = p[tile * blockDim.x + threadIdx.x];
      spos[threadIdx.x] = make_float3(tpos.x, tpos.y, tpos.z);
      __syncthreads();
      /*********************************/

      // LOOP BLOCK
      for (int j = 0; j < BLOCK_SIZE; j++) {
        /*********************************/
        /* LOADS FROM SCRATCH PAD */
        float dx = spos[j].x - p[i].x;
        float dy = spos[j].y - p[i].y;
        float dz = spos[j].z - p[i].z;
        /********************************/
        float distSqr = dx*dx + dy*dy + dz*dz + SOFTENING;
        float invDist = rsqrtf(distSqr);
        float invDist3 = invDist * invDist * invDist;
        Fx += dx * invDist3; 
        Fy += dy * invDist3; 
        Fz += dz * invDist3;
      }
      __syncthreads();
    }
    v[i].x += dt*Fx;
    v[i].y += dt*Fy;
    v[i].z += dt*Fz;
  }
}

Note 1: when reading from the scratch pad in LOOP BLOCK, all threads of the block request the same shared-memory position (spos[j]), so the access is served by a single broadcast (see the official documentation). This is the main reason the shared-memory implementation is so much faster than the global-memory one for the n-body simulation.

Note 2: accessing shared memory with a stride of three 32-bit floats does not cause bank conflicts; see the official document.
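
As a quick sanity check of Note 2, here is a small Python sketch (assuming the usual 32 shared-memory banks of 4-byte words) showing that 32 consecutive threads reading 32-bit floats with a stride of three words all hit distinct banks, since gcd(3, 32) = 1:

NUM_BANKS = 32  # banks of 4-byte words on current CUDA GPUs

# Thread t reads word index 3 * t (a stride of three 32-bit floats).
banks = [(3 * t) % NUM_BANKS for t in range(NUM_BANKS)]

# All 32 accesses land in distinct banks, so there is no bank conflict.
assert len(set(banks)) == NUM_BANKS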

Currently we cannot easily implement this in Taichi. We need a set of new APIs to enable the optimization.

The Hack

In fact, most of the required utilities are already present in the codebase, so I made a prototype that manipulates the CHI IR to demonstrate how this works in Taichi.

Local Array Allocation

In CUDA, the shared memory is allocated through the following statement:

__shared__ float3 spos[BLOCK_SIZE];

In the current Taichi codebase, setting bls_size in the OffloadedStmt causes the following allocation snippet to be invoked at codegen time:

void create_bls_buffer(OffloadedStmt *stmt) {
    auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context),
                                     stmt->bls_size);
    bls_buffer = new GlobalVariable(
        *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr,
        "bls_buffer", nullptr, llvm::GlobalVariable::NotThreadLocal,
        3 /*addrspace=shared*/);
    bls_buffer->setAlignment(llvm::MaybeAlign(8));
  }

I think we need a dedicated statement for this; a specialized AllocaStmt should work.

Data fetch into scratch pad

CUDA code:

float4 tpos = p[tile * blockDim.x + threadIdx.x];
spos[threadIdx.x] = make_float3(tpos.x, tpos.y, tpos.z);

See this part for the hack.

          // Load value from field.
          // src_val = bodies[src_offset, 0/1/2]
          std::vector<Stmt *> global_load_index;

          global_load_index.push_back(fetch_block->push_back<BinaryOpStmt>(
              BinaryOpType::add, src_offset, fetch_block->push_back<ConstStmt>(
              TypedConstant((int32_t)(index_xyz * sizeof(float))))));

          auto src_base_ptr = fetch_block->push_back<GlobalPtrStmt>(
              global_ptr_stmt->snodes, global_load_index, false);
          auto src_val = fetch_block->push_back<GlobalLoadStmt>(src_base_ptr);

          // smem[3 * threadIdx.x + 0/1/2] = src_val
          auto offset_base = fetch_block->push_back<BinaryOpStmt>(
              BinaryOpType::mul, tid,
              fetch_block->push_back<ConstStmt>(TypedConstant(3)));
          auto offset = fetch_block->push_back<BinaryOpStmt>(
              BinaryOpType::add, offset_base,
              fetch_block->push_back<ConstStmt>(TypedConstant(index_xyz)));
          auto local_offset = fetch_block->push_back<BinaryOpStmt>(
              BinaryOpType::mul, offset,
              fetch_block->push_back<ConstStmt>(
                  TypedConstant((int32)sizeof(float))));

          auto bls_ptr = fetch_block->push_back<BlockLocalPtrStmt>(
              local_offset, data_type);
          auto bls_store =
              fetch_block->push_back<GlobalStoreStmt>(bls_ptr, src_val);

where index_xyz is the 0/1/2 offset for the x, y, and z components.

In this part, we first get the corresponding SNode from the actual GlobalPtrStmt and build the new statements from that SNode and threadIdx.x. Then we obtain the pointer into the BLS buffer via BlockLocalPtrStmt.
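
Ignoring the IR plumbing, the statements above amount to the following element-level copy per thread (a plain-Python sketch; the function name and arguments are illustrative, with tid standing for threadIdx.x within the block):

def fetch_tile_into_pad(bodies, pad, j, block_dim, tid):
    # bodies: flat array laid out as [x, y, z, w] per body (global memory)
    # pad:    per-block scratch pad laid out as [x, y, z] per body
    for c in range(3):  # index_xyz = 0, 1, 2
        src = bodies[(j * block_dim + tid) * 4 + c]  # GlobalLoadStmt from DRAM
        pad[3 * tid + c] = src                       # store through BlockLocalPtrStmt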

I think the BlockLocalPtrStmt doesn't make sense and should be removed. It can be replaced with a GlobalPtrStmt or GlobalVariableStmt with the proper arguments passed in.

Replace Load Statements

Finally, the statements in the LOADS FROM DRAM section should be replaced with loads from the scratch pad.
See here

The final performance data is sound!


API Proposal

I'd like to write the following Taichi code for this optimization:

import taichi as ti

nBodies = 3072
softening = 1e-9
dt = 0.01
block_dim=128

ti.init(arch=ti.cuda)

bodies = ti.field(shape=(nBodies, 4), dtype=ti.float32)
velocities = ti.field(shape=(nBodies, 3), dtype=ti.float32)
@ti.kernel
def bodyForce():
    block_dim = ti.static(128)
    ti.loop_config(block_dim=block_dim)
    # LOOP 1
    for i in range(nBodies):
        Fx = 0.0
        Fy = 0.0
        Fz = 0.0
        bx = bodies[i, 0]
        by = bodies[i, 1]
        bz = bodies[i, 2]
        # LOOP 2
        for j in range(nBodies // block_dim):
            # FETCH
            pad = ti.simt.scratch_pad((block_dim, 3), dtype=ti.f32, block_local=True)
            pad[ti.simt.local_thread_idx(), 0] = bodies[j * block_dim + ti.simt.local_thread_idx(), 0]
            pad[ti.simt.local_thread_idx(), 1] = bodies[j * block_dim + ti.simt.local_thread_idx(), 1]
            pad[ti.simt.local_thread_idx(), 2] = bodies[j * block_dim + ti.simt.local_thread_idx(), 2]
            ti.simt.barrier()
            # LOOP BLOCK
            for t in range(block_dim // 8):
                # LOOP UNROLL
                for s in ti.static(range(8)):
                    #############
                    # LOADS FROM PAD
                    dx = pad[t * 8 + s, 0] - bx
                    dy = pad[t * 8 + s, 1] - by
                    dz = pad[t * 8 + s, 2] - bz
                    #############
                    distSqr = dx * dx + dy * dy + dz * dz + softening
                    invDist = 1.0 / ti.sqrt(distSqr)
                    invDist3 = invDist * invDist * invDist
                    Fx += dx * invDist3
                    Fy += dy * invDist3
                    Fz += dz * invDist3
        velocities[i, 0] += dt * Fx
        velocities[i, 1] += dt * Fy
        velocities[i, 2] += dt * Fz

bodyForce()

It seems that we only need to modify AllocaStmt to invoke the correct allocation code, and to take the scratch pad into account in the upcoming Mat/Vec types.

bobcao3 (Collaborator) commented Jul 6, 2022

^ Please move the existing TLS & BLS code to alloca statements as well

turbo0628 (Member, Author) commented:

Can't wait to do that

strongoier (Contributor) commented:

I suggest not mixing the scratch pad with ti.Matrix. A scratch pad is a pure storage concept for advanced GPU programmers, while ti.Matrix is mostly a mathematical concept for Taichi users. The ways of using them are completely different. IMO we can introduce something new like ti.SharedArray (reference: https://numba.readthedocs.io/en/stable/cuda/examples.html).
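
For reference, the Numba pattern linked above looks roughly like this (a minimal sketch using numba.cuda.shared.array; the kernel and names below are illustrative, not Taichi code):

from numba import cuda, float32

BLOCK_DIM = 128

@cuda.jit
def copy_via_shared(src, dst):
    # Per-block shared array with a compile-time shape and dtype.
    pad = cuda.shared.array(shape=(BLOCK_DIM,), dtype=float32)
    i = cuda.grid(1)
    if i < src.shape[0]:
        pad[cuda.threadIdx.x] = src[i]  # stage through shared memory
        cuda.syncthreads()
        dst[i] = pad[cuda.threadIdx.x]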

turbo0628 (Member, Author) commented Jul 8, 2022

In the context of this discussion, using ti.Matrix and ti.Vector is a demonstration of the internal design rather than the final user API. To be more specific, I hope to use the Mat/Vec type internally for the shared-memory storage, since it has all the features I need to implement the scratch pad. Meanwhile, I think it is good to expose a clearer name to users, or even a namespace for all optimization-oriented features (like ti.simt). The original proposal was ti.simt.scratch_pad((m, n), dtype=ti.f32), but after some discussions I found it identical to the Mat APIs.

WDYT about using the Mat type internally? We can decide the final API exposed to users when things get ready

strongoier (Contributor) commented:

I think for demonstration it is fine. However, when it comes to merging code in, it may not be the best choice to complicate ti.Matrix even more. What features do you want on Python side? If the only requirement is subscript, IMO adding a new code path outside ti.Matrix is not that heavy.

k-ye (Member) commented Jul 8, 2022

+1. Matrix and vector are meant to be mathematical types. In this case, though, we need a "buffer storage" type, and a dedicated type for that makes sense (before we unify all the pointers).

turbo0628 (Member, Author) commented:

As there's already a "scratch pad" concept in the codebase, let's go with the name Shared Array instead!
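
A rough sketch of how the FETCH stage from the proposal could then read with such a shared array type (all API names below, e.g. ti.simt.block.SharedArray, thread_idx, and sync, are placeholders rather than the final merged interface):

# Inside the LOOP 2 body of the proposed bodyForce kernel:
pad = ti.simt.block.SharedArray((block_dim, 3), ti.f32)  # placeholder constructor
tid = ti.simt.block.thread_idx()                         # placeholder in-block thread index
for c in ti.static(range(3)):
    pad[tid, c] = bodies[j * block_dim + tid, c]
ti.simt.block.sync()                                     # placeholder block barrier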
