
TC-GNN Artifact for USENIX ATC'23.

  • Cite this project and paper:
@inproceedings{TC-GNN,
  title={TC-GNN: Bridging Sparse GNN Computation and Dense Tensor Cores on GPUs},
  author={Yuke Wang and Boyuan Feng and Zheng Wang and Guyue Huang and Yufei Ding},
  booktitle={USENIX Annual Technical Conference (ATC)},
  year={2023}
}

1. Clone this project.

git clone --recursive git@github.com:YukeWang96/TCGNN-Pytorch.git
  • Requirements:
  • Ubuntu 16.04+
  • gcc >= 7.5
  • cmake >= 3.14
  • CUDA >= 11.0 and nvcc >= 11.0
  • NVIDIA GPU with sm >= 80 (e.g., Ampere GPUs such as the RTX 3090).

2. Environment Setup. (Skip this step if evaluating on the provided server.)

2.1 [Method-1] Install via Docker (Recommended).

  • Go to docker/
  • Run ./build.sh
  • Run ./launch.sh

2.2 [Method-2] Install via Conda.

  • 2.2.1 Install conda on your system (see the official conda installation tutorial).
  • 2.2.2 Create a conda environment:
conda create -n env_name python=3.6
  • 2.2.3 Install Pytorch:
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge

or install via pip (make sure the pip you use is the one from the current conda environment; you can check this with which pip):

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
conda install -c dglteam dgl-cuda11.0
pip install torch requests tqdm
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
pip install torch-geometric
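
A quick environment sanity check (a minimal sketch, not part of the artifact scripts) to confirm that PyTorch sees CUDA and an sm_80-class GPU before building TC-GNN:

import torch

print("torch version:", torch.__version__)            # expect 1.8.0+cu111 with the install above
print("CUDA available:", torch.cuda.is_available())   # must be True
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
    print("compute capability:", torch.cuda.get_device_capability(0))  # expect (8, 0) or higher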

3. Install TC-GNN.

Go to TCGNN_conv/, then run

./0_build_tcgnn.sh

to install the TCGNN_conv module with PyTorch bindings. Note that this step is required for both the Docker and Conda setups.
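
After the build finishes, a quick import check (a minimal sketch; the module name TCGNN is an assumption here, see TCGNN_conv/setup.py for the exact extension name) confirms that the PyTorch binding is importable:

import TCGNN  # extension name assumed; adjust to match TCGNN_conv/setup.py
print("TCGNN extension loaded from:", TCGNN.__file__)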

4. Download graph datasets.

Get the preprocessed datasets:

wget https://storage.googleapis.com/graph_dataset/tcgnn-ae-graphs.tar.gz
tar -zxvf tcgnn-ae-graphs.tar.gz && rm -rf tcgnn-ae-graphs.tar.gz

5. Running TC-GNN end-to-end model evaluation.

  • Go to project root directory.
  • ./0_run_tcgnn_model.sh to run all TC-GNN experiments.
  • Check the results in 1_bench_gcn.csv and 1_bench_agnn.csv.

6. Running DGL baseline (Fig-6a).

  • Go to dgl_baseline/ directory.
  • ./0_run_dgl.sh to run all DGL experiments.
  • Check the results in Fig_6a_dgl_gcn.csv and Fig_6a_dgl_agnn.csv.

7. Running PyG baseline (Fig-6b).

  • Go to pyg_baseline/ directory;
  • ./0_run_pyg.sh to run all PyG experiments.
  • Check the results in Fig_6b_PyG_gcn.csv and Fig_6b_PyG_agnn.csv.

8. Running TC-GNN single-kernel comparison.

  • Go to project root directory.
  • ./0_run_tcgnn_single_kernel.sh to run TC-GNN single-kernel experiments.
  • Check the results in 1_bench_gcn.csv and 1_bench_agnn.csv.

9. cuSPARSE-bSpMM Baseline (Fig-6c).

cd TCGNN-bSpmm/cusparse
./0_run_bSpMM.sh
  • Check the results in Fig_6c_cuSPARSE_bSpMM.csv.

10. Dense Tile Reduction (Fig-7).

python 3_cnt_TC_blk_SDDMM.py
python 3_cnt_TC_blk_SpMM.py
  • Check the results in 3_cnt_TC_blk_SDDMM.csv and 3_cnt_TC_blk_SpMM.csv.

11. tSparse Baseline (Table-5, column-2) (running outside the docker).

cd TCGNN-tsparse/
./0_run_tSparse.sh
  • Check the results in Table_5_tSparse.csv.

12. Triton Baseline (Table-5, column-3) (running outside the docker in a triton conda env).

cd TCGNN-trition/python/bench
conda activate triton
./0_run_triton.sh
  • Check the results in 1_run_triton.csv.

13. Use TC-GNN as a Tool or Library for your project.

Building a new design based on TC-GNN is simple; it takes only a few steps:

13.1 Register a new PyTorch Operator.

  • Add a compilation entry in TCGNN.cpp under TCGNN_conv/. An example is shown below.

std::vector<torch::Tensor> spmm_forward(
    torch::Tensor input,
    torch::Tensor nodePointer,
    torch::Tensor edgeList,
    torch::Tensor blockPartition,
    torch::Tensor edgeToColumn,
    torch::Tensor edgeToRow
) {
    CHECK_INPUT(input);
    CHECK_INPUT(nodePointer);
    CHECK_INPUT(edgeList);
    CHECK_INPUT(blockPartition);
    CHECK_INPUT(edgeToColumn);
    CHECK_INPUT(edgeToRow);

    int num_nodes = nodePointer.size(0) - 1;
    int num_edges = edgeList.size(0);
    int embedding_dim = input.size(1);

    return spmm_forward_cuda(nodePointer, edgeList,
                             blockPartition, edgeToColumn, edgeToRow,
                             num_nodes, num_edges, embedding_dim,
                             input);
}

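The binding itself is registered in the Python-binding block of TCGNN.cpp, i.e., inside the PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) block that a standard PyTorch C++ extension defines:
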
m.def("forward", &spmm_forward, "TC-GNN SPMM forward (CUDA)");

13.2 Build the C++ design based on our existing examples.

  • Add the operator implementation in the TCGNN_kernel.cpp file under TCGNN_conv/. An example is shown below.

std::vector<torch::Tensor> spmm_forward_cuda(
    torch::Tensor nodePointer,
    torch::Tensor edgeList,
    torch::Tensor blockPartition,
    torch::Tensor edgeToColumn,
    torch::Tensor edgeToRow,
    int num_nodes,
    int num_edges,
    int embedding_dim,
    torch::Tensor input
) {
    auto output = torch::zeros_like(input);
    const int num_row_windows = blockPartition.size(0);
    const int WARPperBlock = WPB;

    dim3 grid(num_row_windows, 1, 1);
    dim3 block(WARP_SIZE, WARPperBlock, 1);

    const int dimTileNum = (embedding_dim + BLK_H - 1) / BLK_H;
    const int dynamic_shared_size = dimTileNum * BLK_W * BLK_H * sizeof(float); // dynamic shared memory.

    spmm_forward_cuda_kernel<<<grid, block, dynamic_shared_size>>>(
        nodePointer.data<int>(),
        edgeList.data<int>(),
        blockPartition.data<int>(),
        edgeToColumn.data<int>(),
        edgeToRow.data<int>(),
        num_nodes,
        num_edges,
        embedding_dim,
        input.data<float>(),
        output.data<float>()
    );

    // Check for errors.
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) {
        // Print the CUDA error message and exit.
        printf("CUDA error: %s\n", cudaGetErrorString(error));
        exit(-1);
    }
    return {output};
}

13.3 Build the CUDA kernel design based on our existing examples.

  • Add a CUDA kernel design in TCGNN_kernel.cuh. An example is shown below.

__global__ void spmm_forward_cuda_kernel(
    const int *__restrict__ nodePointer,     // node pointer.
    const int *__restrict__ edgeList,        // edge list.
    const int *__restrict__ blockPartition,  // number of TC_blocks (16x8) in each row_window.
    const int *__restrict__ edgeToColumn,    // eid -> col within each row_window.
    const int *__restrict__ edgeToRow,       // eid -> row within each row_window.
    const int numNodes,
    const int numEdges,
    const int embedding_dim,                 // embedding dimension.
    const float *__restrict__ input,         // input feature matrix.
    float *output                            // aggregated output feature matrix.
) {
    const unsigned bid = blockIdx.x;                          // block_index == row_window_index.
    const unsigned wid = threadIdx.y;                         // warp_index handling multi-dimension > 16.
    const unsigned laneid = threadIdx.x;                      // lane id of each warp.
    const unsigned tid = threadIdx.y * blockDim.x + laneid;   // thread id within each block.
    const unsigned warpSize = blockDim.x;                     // number of threads per warp.
    const unsigned threadPerBlock = blockDim.x * blockDim.y;  // number of threads per block.

    const unsigned dimTileNum = embedding_dim / BLK_H;              // number of tiles along the dimension.
    const unsigned nIdx_start = bid * BLK_H;                        // starting nodeIdx of current row_window.
    const unsigned nIdx_end = min((bid + 1) * BLK_H, numNodes);     // ending nodeIdx of current row_window.
    const unsigned eIdx_start = nodePointer[nIdx_start];            // starting edgeIdx of current row_window.
    const unsigned eIdx_end = nodePointer[nIdx_end];                // ending edgeIdx of the current row_window.
    const unsigned num_TC_blocks = blockPartition[bid];             // number of TC_blocks of the current row_window.
    const unsigned dense_bound = numNodes * embedding_dim;

    __shared__ float sparse_A[BLK_H * BLK_W];   // row-major sparse matrix shared memory store.
    __shared__ int sparse_AToX_index[BLK_W];    // TC_block col to dense_tile row.
    // __shared__ float dense_X[dimTileNum * BLK_W * BLK_H]; // column-major dense tile [dimTileNum, BLK_W, BLK_H].
    extern __shared__ float dense_X[];

    wmma::fragment<wmma::matrix_a, BLK_H, BLK_H, BLK_W, wmma::precision::tf32, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, BLK_H, BLK_H, BLK_W, wmma::precision::tf32, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, BLK_H, BLK_H, BLK_W, float> acc_frag;
    wmma::fill_fragment(acc_frag, 0.0f);

    // Processing TC_blocks along the column dimension of sparse A.
    for (unsigned i = 0; i < num_TC_blocks; i++){
        // Init sparse_AToX_index with dummy values.
        if (tid < BLK_W){
            sparse_AToX_index[tid] = numNodes + 1;
        }
        __syncthreads();

        // Init sparse_A with zero values.
        #pragma unroll
        for (unsigned idx = tid; idx < BLK_W * BLK_H; idx += threadPerBlock){
            sparse_A[idx] = 0;
        }

        // Init dense_X with zero values.
        #pragma unroll
        for (unsigned idx = tid; idx < dimTileNum * BLK_W * BLK_H; idx += threadPerBlock){
            dense_X[idx] = 0;
        }

        // Initialize sparse_A: fetch all neighbors of the current nodes,
        // then check whether each edge falls into the current TC_block column frame.
        #pragma unroll
        for (unsigned eIdx = eIdx_start + tid; eIdx < eIdx_end; eIdx += threadPerBlock){
            unsigned col = edgeToColumn[eIdx];
            if (i * BLK_W <= col && col < (i + 1) * BLK_W){      // if the edge is in the current TC_block column frame.
                unsigned row_local = edgeToRow[eIdx] % BLK_H;
                unsigned col_local = col % BLK_W;
                sparse_A[row_local * BLK_W + col_local] = 1;     // set the edge of sparse_A.
                sparse_AToX_index[col_local] = edgeList[eIdx];   // record the mapping from sparse_A colId to rowId of dense_X.
            }
        }
        __syncthreads();

        // Initialize dense_X with a column-major store;
        // the threads of one warp fetch one dense_X tile, and each warp is identified by wid.
        if (wid < dimTileNum)
            #pragma unroll
            for (unsigned idx = laneid; idx < BLK_W * BLK_H; idx += warpSize){
                unsigned dense_rowIdx = sparse_AToX_index[idx % BLK_W];   // TC_block_col to dense_tile_row.
                unsigned dense_dimIdx = idx / BLK_W;                      // dimIndex of the dense tile.
                unsigned source_idx = dense_rowIdx * embedding_dim + wid * BLK_H + dense_dimIdx;
                unsigned target_idx = wid * BLK_W * BLK_H + idx;
                // Boundary test.
                if (source_idx >= dense_bound)
                    dense_X[target_idx] = 0;
                else
                    dense_X[target_idx] = input[source_idx];
            }
        __syncthreads();

        if (wid < dimTileNum)
        {
            wmma::load_matrix_sync(a_frag, sparse_A, BLK_W);
            wmma::load_matrix_sync(b_frag, dense_X + wid * BLK_W * BLK_H, BLK_W);

            #pragma unroll
            for (unsigned t = 0; t < a_frag.num_elements; t++) {
                a_frag.x[t] = wmma::__float_to_tf32(a_frag.x[t]);
            }
            #pragma unroll
            for (unsigned t = 0; t < b_frag.num_elements; t++) {
                b_frag.x[t] = wmma::__float_to_tf32(b_frag.x[t]);
            }
            // Perform the matrix multiplication.
            wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
        }
    }

    if (wid < dimTileNum)
        // Store the result to the output matrix.
        // * Note *: the embedding dimension should be padded to a multiple of BLK_H for output correctness.
        wmma::store_matrix_sync(output + bid * BLK_H * embedding_dim + wid * BLK_H, acc_frag, embedding_dim, wmma::mem_row_major);
}

13.4 Launch the TC-GNN docker and recompile.

  • The compiled executable will be located under build/.
cd docker 
./launch.sh
./0_build_tcgnn.sh
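
For reference, here is a minimal Python sketch of the calling convention for the spmm_forward operator bound above. This is illustrative only: the module name TCGNN is an assumption (see TCGNN_conv/setup.py), and the all-zero metadata tensors are placeholders; in a real run, nodePointer, edgeList, blockPartition, edgeToColumn, and edgeToRow come from TC-GNN's graph preprocessing, and embedding_dim must be padded to a multiple of BLK_H (16).

import torch
import TCGNN  # compiled extension; name assumed

num_nodes, embedding_dim = 16, 16                      # one 16-row window; dim is a multiple of BLK_H (16).
x = torch.randn(num_nodes, embedding_dim, device="cuda")

# Placeholder graph metadata (int32, on GPU); real values come from TC-GNN preprocessing.
nodePointer    = torch.zeros(num_nodes + 1, dtype=torch.int32, device="cuda")
edgeList       = torch.zeros(0, dtype=torch.int32, device="cuda")
blockPartition = torch.zeros(1, dtype=torch.int32, device="cuda")   # TC blocks per row window.
edgeToColumn   = torch.zeros(0, dtype=torch.int32, device="cuda")
edgeToRow      = torch.zeros(0, dtype=torch.int32, device="cuda")

out, = TCGNN.forward(x, nodePointer, edgeList, blockPartition, edgeToColumn, edgeToRow)
print(out.shape)   # same shape as x; aggregated node features.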

References.

  • Deep Graph Library
    Wang, Minjie, et al. Deep Graph Library: A graph-centric, highly-performant package for graph neural networks. The International Conference on Learning Representations (ICLR), 2019.

  • Pytorch Geometric
    Fey, Matthias, and Jan Eric Lenssen. Fast graph representation learning with PyTorch Geometric. The International Conference on Learning Representations (ICLR), 2019.

  • ASpT
    Hong, Changwan, et al. Adaptive sparse tiling for sparse matrix multiplication. In Proceedings of the 24th Symposium on Principles and Practice of Parallel Programming (PPoPP), 2019.

  • tSparse
    Zachariadis, O., et al. Accelerating Sparse Matrix-Matrix Multiplication with GPU Tensor Cores. Computers & Electrical Engineering (2020).

  • cuSPARSELt
    NVIDIA. Exploiting NVIDIA Ampere Structured Sparsity with cuSPARSELt.