In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# Import the model classes
from rwkv_block.v6_finch.model.rwkv6_finch_model import RWKV6FinchModel
from rwkv_block.v6_finch.model.rwkv6_finch_config_map import RWKV6FinchConfigMap

# Checkpoint file used throughout this notebook
MODEL_FILENAME="v6-1B6.pth"

# Compute dtype, and the device to run on: the first CUDA device when one
# is available, otherwise the CPU
RUN_DTYPE=torch.bfloat16
RUN_DEVICE="cuda:0" if torch.cuda.is_available() else "cpu"

# Stop early when the reference checkpoint has not been downloaded yet
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Read the checkpoint onto the CPU, memory-mapped and with weights_only enabled
model_weight = torch.load(
    f"./.model/{MODEL_FILENAME}",
    map_location='cpu',
    weights_only=True,
    mmap=True,
)

# Report which checkpoint was loaded
print(f"### Model filename: {MODEL_FILENAME}")

# n_dim is read from the second axis of the embedding table
n_dim = model_weight['emb.weight'].shape[1]
print(f"### Model n_dim: {n_dim}")

# Dump every weight name together with its shape and dtype
print("### model weights keys:")
for key, weight in model_weight.items():
    print(f"{key}: {weight.shape} - {weight.dtype}")

### Model filename: v6-1B6.pth
### Model n_dim: 2048
### model weights keys:
emb.weight: torch.Size([65536, 2048]) - torch.bfloat16
blocks.0.ln1.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln1.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.att.time_maa_x: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_w: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_k: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_v: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_r: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_g: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_w1: torch.Size([2048, 160]) - torch.bfloat16
blocks.0.att.time_maa_w2: torch.Size([5, 32, 2048]) - torch.bfl

In [2]:
BATCH_SIZE=1
TEST_COUNT=1000
TEST_LOOP=1
# GPU_COUNT=1

@torch.inference_mode()
def testForwardPass(smodel, compile_type=False):
    # Lets prepare the states accordingly
    in_state = smodel.get_init_state(BATCH_SIZE)
    out_state = smodel.get_init_state(BATCH_SIZE)
    x_tokens = torch.ones(BATCH_SIZE, 1, device=smodel.emb.weight.device, dtype=torch.int)
    # out_emb = torch.zeros(BATCH_SIZE, 1, n_dim, device=smodel.emb.weight.device, dtype=smodel.emb.weight.dtype)

    # Lets test more aggressively
    time0 = time.time()
    if compile_type == "default":
        for i in range(TEST_COUNT):
            smodel.forward_with_default_compile(x_tokens, in_state, out_state)
    elif compile_type == "reduce":
        for i in range(TEST_COUNT):
            smodel.forward_with_reduce_compile(x_tokens, in_state)
    else:
        for i in range(TEST_COUNT):
            smodel.forward(x_tokens, in_state, out_state)
    time1 = time.time()

    print("--")
    print(f"### Compile Type: {compile_type}")
    print("--")
    print(f"### (warmup) Avg time per token batch ({BATCH_SIZE}):", (time1-time0)*1000/TEST_COUNT, "ms")
    print(f"### (warmup) Avg tok/s batch ({BATCH_SIZE}) :", 1000/((time1-time0)*1000/TEST_COUNT), "tok/s")
    print(f"### (warmup) Avg time per token unbatched :", (time1-time0)*1000/TEST_COUNT/BATCH_SIZE, "ms")
    print(f"### (warmup) Avg tok/s unbatched :", 1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE), "tok/s")
    # print(f"### (warmup) Avg tok/s unbatched / gpu :", (1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE))/GPU_COUNT, "tok/s")

    for i in range(TEST_LOOP):
        time0 = time.time()
        if compile_type == "default":
            for i in range(TEST_COUNT):
                smodel.forward_with_default_compile(x_tokens, in_state, out_state)
        elif compile_type == "reduce":
            for i in range(TEST_COUNT):
                smodel.forward_with_reduce_compile(x_tokens, in_state)
        else:
            for i in range(TEST_COUNT):
                smodel.forward(x_tokens, in_state, out_state)
        time1 = time.time()
        print("--")
        print(f"### (actual) Avg time per token batch ({BATCH_SIZE}):", (time1-time0)*1000/TEST_COUNT, "ms")
        print(f"### (actual) Avg tok/s batch ({BATCH_SIZE}) :", 1000/((time1-time0)*1000/TEST_COUNT), "tok/s")
        print(f"### (actual) Avg time per token unbatched :", (time1-time0)*1000/TEST_COUNT/BATCH_SIZE, "ms")
        print(f"### (actual) Avg tok/s unbatched :", 1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE), "tok/s")
        # print(f"### (actual) Avg tok/s unbatched / gpu :", (1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE))/GPU_COUNT, "tok/s")

# Derive the model configuration from the loaded checkpoint
model_config = RWKV6FinchConfigMap.from_model_state_dict(
    model_weight,
    device=RUN_DEVICE,
    dtype=RUN_DTYPE,
)

# Show the derived configuration
print("### Model Config:", model_config, sep="\n")

# Instantiate the model, hand it the checkpoint state dict, then read back
# the state dict the model instance exposes
model_inst = RWKV6FinchModel(model_config)
model_inst.load_from_model_state_dict(model_weight)
model_state = model_inst.state_dict()

# Dump every entry of the model state with its shape and dtype
print("### model weights keys:")
for key, param in model_state.items():
    print(f"{key}: {param.shape} - {param.dtype}")


### Model Config:
RWKV6FinchConfigMap(n_layer=24, n_dim=2048, head_size=64, head_size_divisor=8, dropout_rate=0.0, tmix_backend='auto', n_dim_ffn=7168, n_dim_att=2048, layer_id=None, n_head=32, device='cuda:0', dtype=torch.bfloat16, n_vocab=65536, init_state_wkv=False)
### model weights keys:
emb.weight: torch.Size([65536, 2048]) - torch.bfloat16
blocks.0.ln1.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln1.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.att.time_maa_x: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_w: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_k: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_v: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.time_maa_r: torch.Size([1, 1, 2048

In [3]:
# Benchmark the single token forward pass through the plain `forward` path
# (compile_type is left at its default, so neither compile variant is used)
testForwardPass(model_inst)

  from .autonotebook import tqdm as notebook_tqdm


--
### Compile Type: False
--
### (warmup) Avg time per token batch (1): 52.967305183410645 ms
### (warmup) Avg tok/s batch (1) : 18.87957102097767 tok/s
### (warmup) Avg time per token unbatched : 52.967305183410645 ms
### (warmup) Avg tok/s unbatched : 18.87957102097767 tok/s
--
### (actual) Avg time per token batch (1): 40.08279037475586 ms
### (actual) Avg tok/s batch (1) : 24.94836289216531 tok/s
### (actual) Avg time per token unbatched : 40.08279037475586 ms
### (actual) Avg tok/s unbatched : 24.94836289216531 tok/s


In [4]:
# Benchmark the single token forward pass through `forward_with_default_compile`
testForwardPass(model_inst, "default")

--
### Compile Type: default
--
### (warmup) Avg time per token batch (1): 87.02363443374634 ms
### (warmup) Avg tok/s batch (1) : 11.491131191048215 tok/s
### (warmup) Avg time per token unbatched : 87.02363443374634 ms
### (warmup) Avg tok/s unbatched : 11.491131191048215 tok/s
--
### (actual) Avg time per token batch (1): 44.50671911239624 ms
### (actual) Avg tok/s batch (1) : 22.46851756191291 tok/s
### (actual) Avg time per token unbatched : 44.50671911239624 ms
### (actual) Avg tok/s unbatched : 22.46851756191291 tok/s


In [5]:
# Benchmark the single token forward pass through `forward_with_reduce_compile`
testForwardPass(model_inst, "reduce")

--
### Compile Type: reduce
--
### (warmup) Avg time per token batch (1): 82.70757722854614 ms
### (warmup) Avg tok/s batch (1) : 12.090790632600644 tok/s
### (warmup) Avg time per token unbatched : 82.70757722854614 ms
### (warmup) Avg tok/s unbatched : 12.090790632600644 tok/s
--
### (actual) Avg time per token batch (1): 33.59133315086365 ms
### (actual) Avg tok/s batch (1) : 29.769583586005712 tok/s
### (actual) Avg time per token unbatched : 33.59133315086365 ms
### (actual) Avg tok/s unbatched : 29.769583586005712 tok/s
