## RWKV TMix block python / triton benchmark

Benchmarking the various kernels

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# Import the block classes
from rwkv_block.v7_goose.block.rwkv7_time_mix import RWKV7TimeMix

# File to load
MODEL_FILENAME="v7-1B5-world.pth"

# Run device, and run dtype to use
RUN_DEVICE="cpu"
RUN_DTYPE=torch.bfloat16

# Check for cuda device
if torch.cuda.is_available():
    RUN_DEVICE="cuda:0"

# Check if the reference weights exists
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Loads the model weights
model_weight = torch.load(f"./.model/{MODEL_FILENAME}", map_location='cpu', weights_only=True, mmap=True)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Lets get the hidden_size, and setup the test module
hidden_size = model_weight['emb.weight'].shape[1]
print(f"### Model hidden_size: {hidden_size}")

# # List the model weights keys, and their shapes
# print(f"### model weights keys:")
# for key in model_weight:
#     print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

# Ensure cuda path is set, get the cuda nvcc path
os.environ['CUDA_HOME'] = "/usr/local/cuda"

### Model filename: v7-1B4.pth
### Model hidden_size: 2048


In [2]:
# Initialize the channelmix state, and x state to test
# NOTE: The triton kernel minimum chunk size is 16, it fallsback to pytorch mode otherwise
# we intentionally DO not use a unit of 16, so the remainder pytorch code kicks in for triton
IN_TOKENS_LEN=8192
x_state_0 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_1 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_2 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_0 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_1 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_0 = torch.ones(1, hidden_size // 64, 64, 64, device=RUN_DEVICE, dtype=torch.float)
tmix_wkv_1 = torch.ones(1, hidden_size // 64, 64, 64, device=RUN_DEVICE, dtype=torch.float)

# Iteration to test
TEST_STEPS = 10

# Slower reference implementation
tmix_pytorch = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"pytorch_ref" })
tmix_pytorch.load_from_model_state_dict(model_weight, 0)

# Improved pytorch implement
tmix_pytorch_chunk = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"pytorch" })
tmix_pytorch_chunk.load_from_model_state_dict(model_weight, 0)

tmix_triton = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"triton" })
tmix_triton.load_from_model_state_dict(model_weight, 0)

tmix_triton_bighead = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"triton_bighead" })
tmix_triton_bighead.load_from_model_state_dict(model_weight, 0)

tmix_cuda = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"cuda" })
tmix_cuda.load_from_model_state_dict(model_weight, 0)

tmix_fla = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"fla" })
tmix_fla.load_from_model_state_dict(model_weight, 0)

tmix_fla_fused = RWKV7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"fla_fused" })
tmix_fla_fused.load_from_model_state_dict(model_weight, 0)

print(f"### Testing the tmix blocks for {TEST_STEPS} steps")

### Testing the tmix blocks for 10 steps


In [3]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_pytorch.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_pytorch.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_pytorch reduce-compile forward passes (warmup): 2428.707981109619 ms (cuda:0, torch.bfloat16)
1 tmix_pytorch reduce-compile forward passes (normal): 2041.679072380066 ms (cuda:0, torch.bfloat16)


In [4]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_pytorch_chunk.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch_chunk reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_pytorch_chunk.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch_chunk reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

skipping cudagraphs due to mutated inputs (1 instances). Found from : 
   File "/home/recursal/rwkv-prj/RWKV-block/test/v7_goose/../../rwkv_block/v7_goose/block/kernel/rwkv7_attn_pytorch.py", line 197, in torch_dynamo_resume_in_rwkv7_attn_pytorch_v2_chunk_w_compile_break_at_183
    xx[:] = (wkv_xx.to(dtype=xx.dtype) @ r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,SEQ_LEN,N_HEAD*HEAD_SIZE)



1 tmix_pytorch_chunk reduce-compile forward passes (warmup): 1136.0665798187256 ms (cuda:0, torch.bfloat16)
1 tmix_pytorch_chunk reduce-compile forward passes (normal): 900.4284858703613 ms (cuda:0, torch.bfloat16)


In [5]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_triton.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_triton.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_triton reduce-compile forward passes (warmup): 166.37885570526123 ms (cuda:0, torch.bfloat16)
1 tmix_triton reduce-compile forward passes (normal): 18.13671588897705 ms (cuda:0, torch.bfloat16)


In [6]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_triton_bighead.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton_bighead reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_triton_bighead.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton_bighead reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_triton_bighead reduce-compile forward passes (warmup): 184.83824729919434 ms (cuda:0, torch.bfloat16)
1 tmix_triton_bighead reduce-compile forward passes (normal): 18.22197437286377 ms (cuda:0, torch.bfloat16)


In [7]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_cuda.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_cuda reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_cuda.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_cuda reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

Using /home/recursal/.cache/torch_extensions/py312_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/recursal/.cache/torch_extensions/py312_cu121/state_wind_backstepping/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module state_wind_backstepping...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module state_wind_backstepping...


ninja: no work to do.
1 tmix_cuda reduce-compile forward passes (warmup): 102.49412059783936 ms (cuda:0, torch.bfloat16)
1 tmix_cuda reduce-compile forward passes (normal): 30.61971664428711 ms (cuda:0, torch.bfloat16)


In [8]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_fla.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_fla.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

  from .autonotebook import tqdm as notebook_tqdm


1 tmix_fla reduce-compile forward passes (warmup): 2700.0231742858887 ms (cuda:0, torch.bfloat16)
1 tmix_fla reduce-compile forward passes (normal): 19.12992000579834 ms (cuda:0, torch.bfloat16)


In [9]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_fla_fused.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla_fused reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv, v_first = tmix_fla_fused.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla_fused reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_fla_fused reduce-compile forward passes (warmup): 1176.1919736862183 ms (cuda:0, torch.bfloat16)
1 tmix_fla_fused reduce-compile forward passes (normal): 21.772289276123047 ms (cuda:0, torch.bfloat16)
