In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# Import the block classes
from block.v5_eagle.rwkv5_time_mix import RWKV5TimeMix
from block.v5_eagle.rwkv5_block_config_map import RWKV5BlockConfigMap

# File to load
MODEL_FILENAME="v5-0B4.pth"

# Run device, and run dtype to use
RUN_DEVICE="cpu"
RUN_DTYPE=torch.bfloat16

# Check for cuda device
if torch.cuda.is_available():
    RUN_DEVICE="cuda:0"

# Check if the reference weights exists
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Loads the model weights
model_weight = torch.load(f"./.model/{MODEL_FILENAME}", map_location='cpu', weights_only=True, mmap=True)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Lets get the n_dim, and setup the test module
n_dim = model_weight['emb.weight'].shape[1]
print(f"### Model n_dim: {n_dim}")

# List the model weights keys, and their shapes
print(f"### model weights keys:")
for key in model_weight:
    print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

### Model filename: v5-0B4.pth
### Model n_dim: 1024
### model weights keys:
emb.weight: torch.Size([65536, 1024]) - torch.bfloat16
blocks.0.ln1.weight: torch.Size([1024]) - torch.bfloat16
blocks.0.ln1.bias: torch.Size([1024]) - torch.bfloat16
blocks.0.ln2.weight: torch.Size([1024]) - torch.bfloat16
blocks.0.ln2.bias: torch.Size([1024]) - torch.bfloat16
blocks.0.ln0.weight: torch.Size([1024]) - torch.bfloat16
blocks.0.ln0.bias: torch.Size([1024]) - torch.bfloat16
blocks.0.att.time_mix_k: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_mix_v: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_mix_r: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_mix_g: torch.Size([1, 1, 1024]) - torch.bfloat16
blocks.0.att.time_decay: torch.Size([16, 64]) - torch.bfloat16
blocks.0.att.time_faaaa: torch.Size([16, 64]) - torch.bfloat16
blocks.0.att.receptance.weight: torch.Size([1024, 1024]) - torch.bfloat16
blocks.0.att.key.weight: torch.Size([1024, 1024]) - torch.bfloat

In [2]:
# Initialize the channelmix state, and x state to test
x_state_0 = torch.ones(1, 1, n_dim, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_1 = torch.ones(1, 1, n_dim, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_0 = torch.ones(1, n_dim, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_1 = torch.ones(1, n_dim, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_0 = torch.ones(1, n_dim // 64, 64, 64, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_1 = torch.ones(1, n_dim // 64, 64, 64, device=RUN_DEVICE, dtype=RUN_DTYPE)

# Build the cmix block
tmix = RWKV5TimeMix({ "n_layer":24, "n_dim":n_dim, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE })
tmix.load_from_model_state_dict(model_weight, 0)

# Log each item shape
tmix_state = tmix.state_dict()
print(f"### tmix state keys:")
for key in tmix_state:
    print(f"tmix.{key}: {tmix_state[key].shape} - {tmix_state[key].dtype}")
print("----")

### tmix state keys:
tmix.time_mix_k: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_mix_v: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_mix_r: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_mix_g: torch.Size([1, 1, 1024]) - torch.bfloat16
tmix.time_decay: torch.Size([16, 64]) - torch.bfloat16
tmix.time_faaaa: torch.Size([16, 64]) - torch.bfloat16
tmix.receptance.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.key.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.value.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.output.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.gate.weight: torch.Size([1024, 1024]) - torch.bfloat16
tmix.ln_x.weight: torch.Size([1024]) - torch.bfloat16
tmix.ln_x.bias: torch.Size([1024]) - torch.bfloat16
----


In [3]:
# Iteration to test
TEST_STEPS = 1000

### TMix

# This is a warmup
t0 = time.time()
out_x = x_state_0
t_shift = tmix_shift_0
t_wkv = tmix_wkv_0
for i in range(TEST_STEPS):
    out_x, t_shift, t_wkv = tmix(x_state_1, t_shift, tmix_wkv_1)

# The actual run
t1 = time.time()
out_x = x_state_0
t_shift = tmix_shift_0
t_wkv = tmix_wkv_0
for i in range(TEST_STEPS):
    out_x, t_shift, t_wkv = tmix(x_state_1, t_shift, tmix_wkv_1)
t2 = time.time()
print(f'1 tmix forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


1 tmix forward passes (normal): 0.5500545501708984 ms (cuda:0, torch.bfloat16)


In [4]:
# Iteration to test
TEST_STEPS = 1000

### TMix

# This is a warmup
t0 = time.time()
out_x = x_state_0
t_shift = tmix_shift_0
t_wkv = tmix_wkv_0
for i in range(TEST_STEPS):
    out_x, t_shift, t_wkv = tmix.forward_with_compile(x_state_1, t_shift, tmix_wkv_1, out_x, t_shift, t_wkv)

# The actual run
t1 = time.time()
out_x = x_state_0
t_shift = tmix_shift_0
t_wkv = tmix_wkv_0
for i in range(TEST_STEPS):
    out_x, t_shift, t_wkv = tmix.forward_with_compile(x_state_1, t_shift, tmix_wkv_1, out_x, t_shift, t_wkv)
t2 = time.time()
print(f'1 tmix forward passes (compiled): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


TypeError: RWKV5TimeMix.forward_with_compile() missing 3 required positional arguments: 'out_x', 'shift_state_out', and 'wkv_state_out'