In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# Import the block classes
from rwkv_block.v5_eagle.block.rwkv5_time_mix import RWKV5TimeMix
# from rwkv_block.v5_eagle.block.rwkv5_block_config_map import RWKV5BlockConfigMap

# File to load
MODEL_FILENAME="v5-0B4.pth" #"v5-0B4.pth" #"EagleX-1_7T.pth"

# Run device, and run dtype to use
RUN_DEVICE="cpu"
RUN_DTYPE=torch.bfloat16

# Check for cuda device
if torch.cuda.is_available():
    RUN_DEVICE="cuda:0"

# Check if the reference weights exists
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Loads the model weights
model_weight = torch.load(f"./.model/{MODEL_FILENAME}", map_location='cpu', weights_only=True, mmap=True)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Lets get the hidden_size, and setup the test module
hidden_size = model_weight['emb.weight'].shape[1]
print(f"### Model hidden_size: {hidden_size}")

# List the model weights keys, and their shapes
print(f"### model weights keys:")
for key in model_weight:
    print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

### Model filename: v5-0B4.pth
### Model hidden_size: 1024
### model weights keys:
emb.weight: torch.Size([65536, 1024]) - torch.bfloat16
blocks.0.ln1.weight: torch.Size([1024]) - torch.bfloat16
blocks.0.ln1.bias: torch.Size([1024]) - torch.bfloat16
blocks.0.ln2.weight: torch.Size([1024]) - torch.bfloat16
blocks.0.ln2.bias: torch.Size([1024]) - torch.bfloat16
blocks.0.ln0.weight: torch.Size([1024]) - torch.bfloat16
blocks.0.ln0.bias: torch.Size([1024]) - torch.bfloat16
blocks.0.att.time_mix_k: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_mix_v: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_mix_r: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_mix_g: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_decay: torch.Size([16, 64]) - torch.bfloat16
blocks.0.att.time_faaaa: torch.Size([16, 64]) - torch.bfloat16
blocks.0.att.receptance.weight: torch.Size([1024, 1024]) - torch.bfloat16
blocks.0.att.key.weight: torch.Size([1024, 1024]) - torch.

In [2]:
# Initialize the channelmix state, and x state to test
IN_TOKENS_LEN=8192
x_state_0 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_1 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_0 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_1 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_0 = torch.ones(1, hidden_size // 64, 64, 64, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_1 = torch.ones(1, hidden_size // 64, 64, 64, device=RUN_DEVICE, dtype=RUN_DTYPE)

# Build the cmix block
tmix = RWKV5TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":0, "tmix_backend":"fla", "device":RUN_DEVICE, "dtype":RUN_DTYPE })
tmix.load_from_model_state_dict(model_weight, 0)

# Log each item shape
tmix_state = tmix.state_dict()
print(f"### tmix state keys:")
for key in tmix_state:
    print(f"tmix.{key}: {tmix_state[key].shape} - {tmix_state[key].dtype}")
print("----")

### tmix state keys:
tmix.time_mix_k: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_mix_v: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_mix_r: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_mix_g: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_decay: torch.Size([16, 64]) - torch.bfloat16
tmix.time_faaaa: torch.Size([16, 64]) - torch.bfloat16
tmix.receptance.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.key.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.value.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.output.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.gate.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.ln_x.weight: torch.Size([1024]) - torch.bfloat16
tmix.ln_x.bias: torch.Size([1024]) - torch.bfloat16
----


In [3]:
# Iteration to test
TEST_STEPS = 1000

### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv = tmix(x_state_1, t_shift, tmix_wkv_1)
    t2 = time.time()
    print(f'1 tmix forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv = tmix(x_state_1, t_shift, tmix_wkv_1)
    t2 = time.time()
    print(f'1 tmix forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


  from .autonotebook import tqdm as notebook_tqdm


1 tmix forward passes (warmup): 13.742965936660767 ms (cuda:0, torch.bfloat16)
1 tmix forward passes (normal): 3.2587087154388428 ms (cuda:0, torch.bfloat16)


In [4]:
# Iteration to test
TEST_STEPS = 1000

### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv = tmix.forward_with_default_compile(x_state_1, t_shift, tmix_wkv_1, out_x, t_shift, t_wkv)
    t2 = time.time()
    print(f'1 tmix forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv = tmix.forward_with_default_compile(x_state_1, t_shift, tmix_wkv_1, out_x, t_shift, t_wkv)
    t2 = time.time()
    print(f'1 tmix forward passes (compiled): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


1 tmix forward passes (warmup): 6.560767412185669 ms (cuda:0, torch.bfloat16)
1 tmix forward passes (compiled): 2.9071950912475586 ms (cuda:0, torch.bfloat16)


In [5]:
# Iteration to test
TEST_STEPS = 1000

### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv = tmix.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1)
    t2 = time.time()
    print(f'1 tmix forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    for i in range(TEST_STEPS):
        out_x, t_shift, t_wkv = tmix.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1)
    t2 = time.time()
    print(f'1 tmix forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


1 tmix forward passes (warmup): 3.9972195625305176 ms (cuda:0, torch.bfloat16)
1 tmix forward passes (normal): 2.9207406044006348 ms (cuda:0, torch.bfloat16)
