In [7]:
!pip3 install triton


Usage:   
  pip3 <command> [options]

Commands:
  install                     Install packages.
  download                    Download packages.
  uninstall                   Uninstall packages.
  freeze                      Output installed packages in requirements format.
  inspect                     Inspect the python environment.
  list                        List installed packages.
  show                        Show information about installed packages.
  check                       Verify installed packages have compatible dependencies.
  config                      Manage local and global configuration.
  search                      Search PyPI for packages.
  cache                       Inspect and manage pip's wheel cache.
  index                       Inspect information available from package indexes.
  wheel                       Build wheels from your requirements.
  hash                        Compute hashes of package archives.
  completion                  A helper c

In [2]:
# Put the project root (two directories up) on the import path
import os
import sys
import time

import torch

sys.path.append('../../')

# Time-mix block class exercised by this notebook
from rwkv.v7_goose.block.rwkv7_time_mix import RWKV7TimeMix

# Checkpoint file name, and the location it is read from
MODEL_FILENAME = "v7-1B4.pth"
model_path = f"./.model/{MODEL_FILENAME}"

# Use the first CUDA device when one is available, otherwise stay on CPU
RUN_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
RUN_DTYPE = torch.bfloat16

# Stop early when the checkpoint has not been downloaded yet
assert os.path.exists(model_path), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Read the checkpoint onto CPU (memory-mapped, weights-only unpickling)
model_weight = torch.load(model_path, map_location='cpu', weights_only=True, mmap=True)

# Report which file was loaded
print(f"### Model filename: {MODEL_FILENAME}")

# n_dim is taken from the second dimension of the embedding weight
n_dim = model_weight['emb.weight'].shape[1]
print(f"### Model n_dim: {n_dim}")

# Print the name, shape and dtype of every entry in the checkpoint
print("### model weights keys:")
for key in model_weight:
    weight = model_weight[key]
    print(f"{key}: {weight.shape} - {weight.dtype}")

ModuleNotFoundError: No module named 'triton'

In [None]:
# Inputs and states for the time-mix test; every tensor is filled with ones
IN_TOKENS_LEN = 8192
x_state_0, x_state_1, x_state_2 = (
    torch.ones(1, IN_TOKENS_LEN, n_dim, device=RUN_DEVICE, dtype=RUN_DTYPE)
    for _ in range(3)
)
tmix_shift_0, tmix_shift_1 = (
    torch.ones(1, n_dim, device=RUN_DEVICE, dtype=RUN_DTYPE)
    for _ in range(2)
)
# NOTE(review): 64 looks like the per-head size (n_dim // 64 heads) - confirm
tmix_wkv_0, tmix_wkv_1 = (
    torch.ones(1, n_dim // 64, 64, 64, device=RUN_DEVICE, dtype=RUN_DTYPE)
    for _ in range(2)
)

# Build the tmix block with layer_id 0 and populate it from model_weight
tmix = RWKV7TimeMix({
    "n_layer": 24,
    "n_dim": n_dim,
    "layer_id": 0,
    "device": RUN_DEVICE,
    "dtype": RUN_DTYPE,
})
tmix.load_from_model_state_dict(model_weight, 0)

# List every named parameter with its shape and dtype
tmix_params = tmix.named_parameters()
print("### tmix named parameters:")
for param_name, param_tensor in tmix_params:
    print(f"{param_name}: {param_tensor.shape} - {param_tensor.dtype}")

# List every state-dict entry with its shape and dtype
tmix_state = tmix.state_dict()
print("### tmix state keys:")
for key in tmix_state:
    entry = tmix_state[key]
    print(f"tmix.{key}: {entry.shape} - {entry.dtype}")
print("----")

### tmix named parameters:
x_r: torch.Size([1, 1, 2048]) - torch.bfloat16
x_w: torch.Size([1, 1, 2048]) - torch.bfloat16
x_k: torch.Size([1, 1, 2048]) - torch.bfloat16
x_v: torch.Size([1, 1, 2048]) - torch.bfloat16
x_a: torch.Size([1, 1, 2048]) - torch.bfloat16
x_g: torch.Size([1, 1, 2048]) - torch.bfloat16
w1: torch.Size([2048, 96]) - torch.bfloat16
w2: torch.Size([96, 2048]) - torch.bfloat16
w0: torch.Size([1, 1, 2048]) - torch.bfloat16
a1: torch.Size([2048, 96]) - torch.bfloat16
a2: torch.Size([96, 2048]) - torch.bfloat16
a0: torch.Size([1, 1, 2048]) - torch.bfloat16
v1: torch.Size([2048, 64]) - torch.bfloat16
v2: torch.Size([64, 2048]) - torch.bfloat16
v0: torch.Size([1, 1, 2048]) - torch.bfloat16
g1: torch.Size([2048, 256]) - torch.bfloat16
g2: torch.Size([256, 2048]) - torch.bfloat16
k_k: torch.Size([1, 1, 2048]) - torch.bfloat16
k_a: torch.Size([1, 1, 2048]) - torch.bfloat16
r_k: torch.Size([32, 64]) - torch.bfloat16
receptance.weight: torch.Size([2048, 2048]) - torch.bfloat16
k

In [None]:
# Number of forward passes timed in each phase
TEST_STEPS = 5

# CUDA kernels are launched asynchronously, so the device must be synchronised
# around the timed region; otherwise the wall clock can stop before the work
# has actually finished.
sync_cuda = str(RUN_DEVICE).startswith("cuda")

### TMix (plain forward call)
with torch.inference_mode():

    # The first phase is a warmup, the second is the measurement of interest.
    for phase in ("warmup", "normal"):
        # Reset the carried values so both phases start from the same inputs
        out_x = x_state_0
        t_shift = tmix_shift_0
        t_wkv = tmix_wkv_0
        v_first = x_state_2

        if sync_cuda:
            torch.cuda.synchronize(torch.device(RUN_DEVICE))
        t_start = time.perf_counter()  # monotonic clock meant for interval timing
        for i in range(TEST_STEPS):
            # NOTE(review): the wkv state passed in is always tmix_wkv_1, not the
            # t_wkv returned by the previous step - confirm this is intended.
            out_x, t_shift, t_wkv, v_first = tmix(x_state_1, t_shift, tmix_wkv_1, v_first)
        if sync_cuda:
            torch.cuda.synchronize(torch.device(RUN_DEVICE))
        t_end = time.perf_counter()

        # Average milliseconds per forward pass for this phase
        print(f'1 tmix forward passes ({phase}): {(t_end-t_start)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


1 tmix forward passes (warmup): 16534.054040908813 ms (cpu, torch.bfloat16)
1 tmix forward passes (normal): 16628.37781906128 ms (cpu, torch.bfloat16)


In [None]:
# Number of forward passes timed in each phase
TEST_STEPS = 5

# CUDA kernels are launched asynchronously, so the device must be synchronised
# around the timed region; otherwise the wall clock can stop before the work
# has actually finished.
sync_cuda = str(RUN_DEVICE).startswith("cuda")

### TMix (forward_with_default_compile)
with torch.inference_mode():

    # The first phase is a warmup, the second is the measurement of interest.
    for phase in ("warmup", "compiled"):
        # Reset the carried values so both phases start from the same inputs
        out_x = x_state_0
        t_shift = tmix_shift_0
        t_wkv = tmix_wkv_0
        v_first = x_state_2

        if sync_cuda:
            torch.cuda.synchronize(torch.device(RUN_DEVICE))
        t_start = time.perf_counter()  # monotonic clock meant for interval timing
        for i in range(TEST_STEPS):
            # NOTE(review): the last four arguments look like output buffers, and
            # the wkv input is always tmix_wkv_1 - confirm against the signature
            # of forward_with_default_compile.
            out_x, t_shift, t_wkv, v_first = tmix.forward_with_default_compile(x_state_1, t_shift, tmix_wkv_1, v_first, out_x, t_shift, t_wkv, v_first)
        if sync_cuda:
            torch.cuda.synchronize(torch.device(RUN_DEVICE))
        t_end = time.perf_counter()

        # Average milliseconds per forward pass for this phase
        print(f'1 tmix forward passes ({phase}): {(t_end-t_start)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


In [None]:
# Number of forward passes timed in each phase
TEST_STEPS = 5

# CUDA kernels are launched asynchronously, so the device must be synchronised
# around the timed region; otherwise the wall clock can stop before the work
# has actually finished.
sync_cuda = str(RUN_DEVICE).startswith("cuda")

### TMix (forward_with_reduce_compile)
with torch.inference_mode():

    # The first phase is a warmup, the second is the measurement of interest.
    # The measured phase is labelled "reduce compile" so its output line is
    # distinguishable from the plain-forward benchmark, which prints "normal".
    for phase in ("warmup", "reduce compile"):
        # Reset the carried values so both phases start from the same inputs
        out_x = x_state_0
        t_shift = tmix_shift_0
        t_wkv = tmix_wkv_0
        v_first = x_state_2

        if sync_cuda:
            torch.cuda.synchronize(torch.device(RUN_DEVICE))
        t_start = time.perf_counter()  # monotonic clock meant for interval timing
        for i in range(TEST_STEPS):
            # NOTE(review): the wkv state passed in is always tmix_wkv_1, not the
            # t_wkv returned by the previous step - confirm this is intended.
            out_x, t_shift, t_wkv, v_first = tmix.forward_with_reduce_compile(x_state_1, t_shift, tmix_wkv_1, v_first)
        if sync_cuda:
            torch.cuda.synchronize(torch.device(RUN_DEVICE))
        t_end = time.perf_counter()

        # Average milliseconds per forward pass for this phase
        print(f'1 tmix forward passes ({phase}): {(t_end-t_start)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


In [None]:
# Export the state dict of the first tmix block
tmix_state = tmix.state_dict()

# Log the shape and dtype of every exported entry
print("### tmix state keys:")
for key in tmix_state:
    print(f"tmix.{key}: {tmix_state[key].shape} - {tmix_state[key].dtype}")
print("----")

# Build a second tmix block (layer_id 1, "torch" tmix_backend)
tmix2 = RWKV7TimeMix({ "n_layer":24, "n_dim":n_dim, "layer_id":1, "tmix_backend":"torch", "device":RUN_DEVICE, "dtype":RUN_DTYPE })

# Load the exported state dict into the second block
tmix2.load_state_dict(tmix_state)

# Log the state actually held by tmix2 after loading. Read it back from tmix2
# itself; iterating tmix_state again would only re-print the first block.
tmix2_state = tmix2.state_dict()
print("### tmix2 state keys:")
for key in tmix2_state:
    print(f"tmix2.{key}: {tmix2_state[key].shape} - {tmix2_state[key].dtype}")
print("----")