## Initialization and helper functions

In [1]:
import numpy as np

def softmax(x):
    """
    Row-wise softmax of a 2-D array, computed on rows rescaled to [0, 16].

    Each row is shifted so its minimum becomes 0 and scaled so its maximum
    becomes 16. The exponentials of those rescaled values are then divided
    by their row sum, so every output row sums to 1.

    Parameters
    ----------
    x : 2-D array
        Scores; the softmax is taken along axis 1.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``x`` holding the row-wise softmax.
        A row whose entries are all equal yields a uniform distribution.
    """
    # Shift every row so that its smallest entry is 0.
    shifted = x - np.min(x, axis=1, keepdims=True)
    row_max = np.max(shifted, axis=1, keepdims=True)

    # A constant row has row_max == 0; dividing by it would give 0/0 = NaN
    # for the whole row. Divide by 1 instead so the row stays all-zero.
    safe_max = np.where(row_max == 0, 1, row_max)
    scaled = shifted / safe_max * 16

    exp_x = np.exp(scaled)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

# Notebook-wide configuration constants (defined here, read by later cells).
DMODEL = 64  # Model width: selects the weights header and is the token width
N = 64       # Number of tokens (rows) generated per test case
BITWIDTH = 4 # Bit-width for data_t

# The accumulator bit-width below is not needed here: this notebook computes
# in full precision. The snippet is kept as an inert string for reference.

"""if DMODEL == 16:
    BITWIDTH3 = 2*4 + 4
elif DMODEL == 32:
    BITWIDTH3 = 2*4 + 5
elif DMODEL == 64:
    BITWIDTH3 = 2*4 + 6
else:
    raise ValueError("DMODEL must be 16, 32, or 64")"""

# Presumably here so the cell's last expression is None and the string above
# is not echoed as the cell output.
None

In [2]:
def normalize(x):
    """
    Rescale every row of a 2-D array linearly onto the range [0, 16].

    Each row is shifted so its minimum becomes 0 and then scaled so its
    maximum becomes 16.

    Parameters
    ----------
    x : 2-D array
        Values to rescale; the operation is applied along axis 1.

    Returns
    -------
    np.ndarray
        Float array of the same shape as ``x``. A row whose entries are all
        equal maps to all zeros.
    """
    shifted = x - np.min(x, axis=1, keepdims=True)
    row_max = np.max(shifted, axis=1, keepdims=True)

    # A constant row has row_max == 0; dividing by it would turn the row
    # into NaN (0/0). Divide by 1 instead so the row stays all-zero.
    safe_max = np.where(row_max == 0, 1, row_max)
    return shifted / safe_max * 16

In [3]:
def print_matrix(matrix):
    """
    Print a 2-D iterable to stdout, one row per line.

    Every element is followed by a single space (including the last one in
    a row), and each row is terminated by a newline.
    """
    for values in matrix:
        # Build the whole line first; the trailing space after the last
        # element is part of the output format.
        line = "".join(f"{value} " for value in values)
        print(line)
    

## Load same weights as in the cpp code

In [4]:
import re
import numpy as np


def _parse_matrix(body):
    """Convert the text found between one ``{{`` ... ``}}`` pair into a 2-D int array."""
    rows = []
    for line in body.split("\n"):
        # Strip surrounding whitespace *before* the separating comma, so a
        # row such as "1, 2}, " (blank after the comma) does not produce an
        # empty trailing field that int() would reject.
        line = line.strip().strip(",")
        if not line:
            continue
        # Inner row braces may still cling to the first/last field.
        rows.append([int(field.strip().strip("{}")) for field in line.split(",")])
    return np.array(rows)


def read_weights(file_path):
    """
    Read the first three brace-initialised integer matrices from a text file.

    The file is scanned for ``{{ ... }}`` initialisers with one matrix row
    per line; the first three found are returned, in file order, as the
    Q, K and V weight matrices.

    Parameters
    ----------
    file_path : str
        Path of the file to parse.

    Returns
    -------
    tuple of np.ndarray
        ``(Q_W, K_W, V_W)`` as 2-D integer arrays.

    Raises
    ------
    IndexError
        If the file contains fewer than three ``{{ ... }}`` matrices.
    ValueError
        If a matrix entry is not an integer literal.
    """
    with open(file_path, "r") as file:
        content = file.read()

    # Non-greedy match with DOTALL so each matrix may span several lines.
    matrices = re.findall(r"{\{(.*?)\}\}", content, re.DOTALL)
    if len(matrices) < 3:
        raise IndexError(
            f"expected at least 3 matrices in {file_path!r}, found {len(matrices)}"
        )

    Q_W, K_W, V_W = (_parse_matrix(body) for body in matrices[:3])
    return Q_W, K_W, V_W


# The weights header is selected by model width (e.g. "weights64.h" when
# DMODEL is 64); the path is relative to the current working directory.
file_path = f"weights{DMODEL}.h" 
Q_W, K_W, V_W = read_weights(file_path)


## Create Attention Test Cases

In [5]:
import numpy as np

# Relies on DMODEL, N, Q_W, K_W, V_W, normalize() and softmax() being defined.

# Seed the global NumPy generator so this walkthrough prints the same numbers
# on every fresh run of the notebook.
np.random.seed(0)

# Step 1: Generate random tokens, uniform in [0, 15)
tokens = np.random.rand(N, DMODEL) * 15

# The weight matrices are used as loaded; row-norm scaling is left disabled.
Q_W_norm = Q_W #/ np.linalg.norm(Q_W, axis=1, keepdims=True)
K_W_norm = K_W #/ np.linalg.norm(K_W, axis=1, keepdims=True)
V_W_norm = V_W #/ np.linalg.norm(V_W, axis=1, keepdims=True)

# Step 2: Apply QKV transformations
Q = tokens @ Q_W_norm
K = tokens @ K_W_norm
V = tokens @ V_W_norm

print("Q")
print(Q)
print("K")
print(K)
print("V")
print(V)
print("--------------------")

# Rescale each row of K and Q to [0, 16] before taking the dot products.
K = normalize(K)
Q = normalize(Q)

# Step 3: Compute attention using softmax (the 1/sqrt(DMODEL) scaling is disabled)
K_T = K.T
attention_scores = Q @ K_T #/ np.sqrt(DMODEL)


print("attention scores")
print(attention_scores)

attention_weights = softmax(attention_scores)

print(attention_weights)

# V is rescaled row-wise to [0, 16] as well before being mixed by the weights.
V = normalize(V)
attention_output = attention_weights @ V

# Shown truncated to integers.
print(attention_output.astype(int))

# Step 4 (disabled): Normalize output tokens back into the range 0-15
#output_tokens = np.clip(np.round((attention_output - np.min(attention_output)) / (np.max(attention_output) - np.min(attention_output)) * 15), 0, 15).astype(int)

# Output tokens
#print(output_tokens)


Q
[[3154.51003908 2938.80078764 3018.41099498 ... 3105.40019308
  3170.63162208 2938.0190853 ]
 [4402.33972385 3793.56826785 3681.596246   ... 4123.74392884
  4239.00365098 3802.04750462]
 [4386.10397641 4040.04314257 3587.10217719 ... 3485.40746077
  4001.67878608 3801.30796987]
 ...
 [3997.2773218  3855.05791853 3588.35123864 ... 3903.09782507
  4112.95018134 3770.63129694]
 [4587.62701551 4164.29899085 3946.99502819 ... 4237.80621341
  4570.78982451 3855.47087746]
 [4073.25245531 3741.26278678 3215.79871689 ... 3377.73011459
  3988.62479269 2977.50190218]]
K
[[3433.34272132 2498.82532154 3041.59535035 ... 2547.93298738
  3082.77737138 3359.20386857]
 [4554.6810195  3601.4681042  3977.45159523 ... 3476.1737981
  4257.12666083 4302.86638462]
 [4012.95343964 3133.83872324 3595.8792711  ... 3579.37227315
  4172.78866167 4235.95321087]
 ...
 [4207.01311806 3667.79910154 3626.77605877 ... 3369.82965489
  4136.94448087 4190.32236215]
 [4387.28677312 3681.24361104 4038.6992922  ... 3757.886

In [6]:
# Save 10 input/output pairs to text files, one integer per line.
# Every sample contributes its flattened N*DMODEL tokens to the input file and
# its flattened attention output to the output file.
# Relies on N, DMODEL, Q_W, K_W, V_W, normalize(), softmax() and print_matrix().

np.random.seed(0)  # fixed seed so the generated test vectors are reproducible

# The weights do not change between samples, so rescale them once up front.
Q_W_scaled = normalize(Q_W.T)
K_W_scaled = normalize(K_W.T)
V_W_scaled = normalize(V_W.T)

# File names follow DMODEL (same convention as the weights header), and the
# context manager closes both files even if a sample raises.
with open(f"generate_tests_input{DMODEL}.txt", "w") as file, \
        open(f"generate_tests_output{DMODEL}.txt", "w") as file2:
    for i in range(10):
        tokens = np.random.rand(N, DMODEL) * 15
        tokens_scaled = normalize(tokens)
        Q = tokens_scaled @ Q_W_scaled
        K = tokens_scaled @ K_W_scaled
        V = tokens_scaled @ V_W_scaled

        K_T = K.T
        attention_scores = Q @ K_T #/ np.sqrt(DMODEL)

        attention_weights = softmax(attention_scores)

        V = normalize(V)
        attention_output = attention_weights @ V

        # print max in each row
        print("max index in each row")

        for row_scores in attention_scores:
            print(np.argmax(row_scores))

        print("--------------------")
        print_matrix(V.astype(int))
        print("--------------------")

        # Values are truncated toward zero by int() before being written.
        file.writelines(f"{int(value)}\n" for value in tokens.flatten())
        file2.writelines(f"{int(value)}\n" for value in attention_output.flatten())

max index in each row
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
--------------------
9 14 7 6 0 11 6 4 7 6 7 11 10 9 12 9 7 15 13 9 14 7 5 15 4 16 9 14 7 6 6 12 10 12 9 13 6 14 12 14 9 8 10 11 7 11 2 11 11 11 9 4 8 5 6 10 9 8 8 10 12 10 7 6 
9 11 5 6 0 9 7 4 8 8 7 8 10 7 10 3 9 11 16 10 13 2 7 15 3 15 9 11 8 9 6 13 7 14 9 15 6 13 11 15 9 12 8 9 11 13 3 9 8 13 11 3 7 10 8 2 9 0 8 10 13 10 2 5 
12 13 7 6 4 10 9 3 6 6 5 9 8 10 7 8 9 10 13 7 12 5 0 12 2 11 5 11 8 5 7 10 9 9 7 12 3 10 10 12 11 8 8 7 6 10 3 9 6 12 8 4 9 9 6 10 9 3 9 7 16 9 2 8 
9 13 6 5 0 6 9 3 5 8 8 8 9 7 8 4 8 13 12 6 10 7 2 16 6 11 11 7 8 9 5 9 10 10 11 15 3 14 9 12 9 8 13 7 9 9 6 10 7 11 11 2 6 5 8 7 8 3 11 12 14 10 8 5 
6 13 3 5 0 7 6 1 1 1 9 8 7 5 5 4 7 13 12 8 10 5 0 14 5 13 7 9 6 6 4 8 8 9 10 13 4 11 10 12 8 7 10 5 9 9 3 12 6 9 10 4 9 2 6 5 8 5 7 8 16 6 3 1 
10 14 5 6 0