In [2]:
import torch
import numpy as np
import time
from pytorch_block_sparse import BlockSparseLinear

import torch_sparse
from torch_sparse import spmm

In [3]:
def create_sparse_tensor(size, sparsity):
    rows, cols = size
    num_elements = rows * cols
    num_non_zero = int(num_elements * (1 - sparsity))

    # Random indices for non-zero elements
    indices = np.random.choice(num_elements, num_non_zero, replace=False)
    indices = np.unravel_index(indices, (rows, cols))
    indices = torch.LongTensor(indices)

    # Random values for these indices
    values = torch.randn(num_non_zero)

    # Create sparse tensor
    return torch.sparse_coo_tensor(indices, values, size)

In [9]:
warmup_iterations = 10
total_iterations  = 100
num_tokens_list   = [1, 64, 512]
dmodel_list       = [1024, 4096, 4096*2]
sparsity_list     = [0.95, 0.99, 0.999]

result_dict = []

for dmodel in dmodel_list:
    dff_shared = dmodel * 3
    for num_tokens in num_tokens_list:
        for sparsity in sparsity_list:

            print(dmodel, num_tokens, sparsity)

            # Create a dense matrix (Weight)
            sparse_matrix_mat1 = create_sparse_tensor(size=(dff_shared,dmodel),
                                                        sparsity=sparsity).cuda().to(torch.int)
            dense_matrix_mat1 = sparse_matrix_mat1.to_dense().cuda().to(torch.int)

            # Create a sparse matrix (Activation)
            sparse_matrix_mat2 = create_sparse_tensor(size=(dmodel,num_tokens),
                                                sparsity=sparsity).cuda().to(torch.int)
            dense_matrix_mat2 = sparse_matrix_mat2.to_dense().cuda().to(torch.int)

            blocksparse_fc = BlockSparseLinear(dmodel, dff_shared, density=(1-sparsity)).to(torch.int)

            tmp = sparse_matrix_mat1.coalesce()
            index = tmp.indices()
            value = tmp.values()

            for i in range(warmup_iterations):
                result = torch.sparse.mm(dense_matrix_mat1, sparse_matrix_mat2)
            torch.cuda.synchronize()

            t1 = time.time_ns()
            for i in range(total_iterations):
                result = torch.sparse.mm(dense_matrix_mat1, sparse_matrix_mat2)
            torch.cuda.synchronize()
            t2 = time.time_ns()


            for i in range(warmup_iterations):
                result = torch.mm(dense_matrix_mat1, dense_matrix_mat2)
            torch.cuda.synchronize()

            t3 = time.time_ns()
            for i in range(total_iterations):
                result = torch.mm(dense_matrix_mat1, dense_matrix_mat2)
            torch.cuda.synchronize()
            t4 = time.time_ns()


            for i in range(warmup_iterations):
                result = torch.sparse.mm(sparse_matrix_mat1, sparse_matrix_mat2)
            torch.cuda.synchronize()

            t5 = time.time_ns()
            for i in range(total_iterations):
                result = torch.sparse.mm(sparse_matrix_mat1, sparse_matrix_mat2)
            torch.cuda.synchronize()
            t6 = time.time_ns()


            for i in range(warmup_iterations):
                result = blocksparse_fc(dense_matrix_mat1)
            torch.cuda.synchronize()

            t7 = time.time_ns()
            for i in range(total_iterations):
                result = blocksparse_fc(dense_matrix_mat1)
            torch.cuda.synchronize()
            t8 = time.time_ns()


            for i in range(warmup_iterations):
                result = torch_sparse.spmm(index, value, dff_shared, dmodel, dense_matrix_mat2)
            torch.cuda.synchronize()

            t9 = time.time_ns()
            for i in range(total_iterations):
                result = torch_sparse.spmm(index, value, dff_shared, dmodel, dense_matrix_mat2)
            torch.cuda.synchronize()
            t10 = time.time_ns()

            sparse_gemm = (t2-t1) / 1.0e6
            dense_gemm = (t4-t3) / 1.0e6
            spsp_gemm = (t6-t5) / 1.0e6
            bsp_gemm = (t8-t7) / 1.0e6
            tsp_gemm = (t10-t9) / 1.0e6

            result_dict.append({
                "dmodel" : dmodel,
                "dff_shared" : dff_shared,
                "tokens" : num_tokens,
                "sparsity" : sparsity,
                "dense_gemm" : dense_gemm,
                "sparse_gemm" : sparse_gemm,
                "spsp_gemm" : spsp_gemm,
                "bsp_gemm" : bsp_gemm,
                "tsp_gemm" : tsp_gemm,
            })

1024 1 0.95


TypeError: nn.Module.to only accepts floating point or complex dtypes, but got desired dtype=torch.int32

In [4]:
import pandas as pd

df = pd.DataFrame(result_dict)
display(df)

Unnamed: 0,dmodel,dff_shared,tokens,sparsity,dense_gemm,sparse_gemm,spsp_gemm,bsp_gemm,tsp_gemm
0,1024,3072,1,0.95,1.247041,31.783795,89.500512,17.792116,7.009113
1,1024,3072,1,0.99,1.223108,31.683953,115.858394,10.875904,7.034448
2,1024,3072,1,0.999,1.226122,17.448482,75.207081,9.259757,6.958705
3,1024,3072,64,0.95,4.618658,34.412385,83.896912,17.701424,24.01263
4,1024,3072,64,0.99,3.845554,21.337367,57.465929,10.941535,4.754367
5,1024,3072,64,0.999,3.799233,20.596239,56.819286,9.024782,4.433968
6,1024,3072,512,0.95,20.497444,56.634255,178.708913,17.788524,175.47163
7,1024,3072,512,0.99,20.563533,32.886396,65.741869,10.93228,37.156623
8,1024,3072,512,0.999,20.624803,23.651566,52.329464,9.02516,4.458977
9,4096,12288,1,0.95,13.510477,24.683024,274.088391,700.619937,27.739495
