## Initialization and helper functions

In [1]:
import numpy as np

def softmax(x):
    """
    Row-wise softmax of a 2-D array, taken after rescaling each row to [0, 16].

    Every row is shifted so that its minimum becomes 0 and then scaled so
    that its maximum becomes 16. The softmax is computed on those rescaled
    values, so the result depends only on the relative position of the
    entries inside a row, not on their absolute magnitude.

    Parameters
    ----------
    x : array_like, 2-D
        Input scores. The softmax is computed along axis 1.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose rows each sum to 1. A row whose
        entries are all equal maps to the uniform distribution.
    """
    # Shift each row so that its smallest entry is 0.
    shifted = x - np.min(x, axis=1, keepdims=True)

    # A constant row has a span of 0; dividing by it would produce 0/0 = NaN
    # for the whole row. Dividing by 1 instead leaves such a row at all zeros.
    span = np.max(shifted, axis=1, keepdims=True)
    safe_span = np.where(span == 0, 1, span)
    scaled = shifted / safe_span * 16

    exp_x = np.exp(scaled)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

# Global problem configuration used by the cells that follow.
# These are example values; adjust as necessary.
DMODEL = 32
N = 32
BITWIDTH = 4  # Bit-width for data_t

# BITWIDTH3 is not necessary here because the computation is done in full
# precision. The disabled selection logic is kept for reference; it sets
# BITWIDTH3 to 2*4 + log2(DMODEL):
#
#   if DMODEL == 16:
#       BITWIDTH3 = 2*4 + 4
#   elif DMODEL == 32:
#       BITWIDTH3 = 2*4 + 5
#   elif DMODEL == 64:
#       BITWIDTH3 = 2*4 + 6
#   else:
#       raise ValueError("DMODEL must be 16, 32, or 64")

In [2]:
def normalize(x):
    """
    Rescale every row of a 2-D array linearly onto the range [0, 16].

    Each row is shifted so that its minimum becomes 0 and then scaled so
    that its maximum becomes 16.

    Parameters
    ----------
    x : array_like, 2-D
        Matrix to rescale along axis 1.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x``. A row whose entries are all equal
        is returned as all zeros.
    """
    shifted = x - np.min(x, axis=1, keepdims=True)

    # A constant row has a span of 0; dividing by it would yield 0/0 = NaN.
    # Dividing by 1 instead keeps such a row at all zeros.
    span = np.max(shifted, axis=1, keepdims=True)
    safe_span = np.where(span == 0, 1, span)
    return shifted / safe_span * 16

In [3]:
def print_matrix(matrix):
    """
    Write a 2-D iterable to stdout, one row per line.

    Each element is rendered with its default string formatting and followed
    by a single space, so every printed line ends with a trailing space.
    """
    for row in matrix:
        print("".join(f"{element} " for element in row))
    

## Load the same weights as the C++ code

In [4]:
import re
import numpy as np


def _parse_matrix(body):
    """Turn the text found between ``{{`` and ``}}`` into a 2-D integer array."""
    # Rows are delimited by "},{" (with optional whitespace or newlines in
    # between), so the result does not depend on how the initializer is
    # wrapped across lines.
    rows = re.split(r"\}\s*,\s*\{", body)

    parsed = []
    for row in rows:
        # Drop surrounding whitespace and any stray braces from each entry, and
        # skip empty entries left behind by trailing commas.
        cells = (token.strip().strip("{}").strip() for token in row.split(","))
        values = [int(cell) for cell in cells if cell]
        if values:
            parsed.append(values)
    return np.array(parsed)


def read_weights(file_path):
    """
    Read the first three 2-D integer matrices from a C/C++ header file.

    The file is scanned for nested brace initializers of the form
    ``{{a, b, ...}, {c, d, ...}, ...}``. The first three matches, in file
    order, are returned as ``Q_W``, ``K_W`` and ``V_W``.

    Parameters
    ----------
    file_path : str or path-like
        Path of the header file to parse.

    Returns
    -------
    tuple of numpy.ndarray
        ``(Q_W, K_W, V_W)``, each a 2-D integer array.

    Raises
    ------
    IndexError
        If the file contains fewer than three ``{{...}}`` initializers.
    ValueError
        If an entry of a matrix is not a decimal integer.
    """
    with open(file_path, "r") as file:
        content = file.read()

    # Regular expression to find matrices
    matrices = re.findall(r"\{\{(.*?)\}\}", content, re.DOTALL)
    if len(matrices) < 3:
        # Same exception type as the unchecked indexing this replaces, but
        # with a message that says what is wrong.
        raise IndexError(
            f"expected at least 3 matrices in {file_path!r}, found {len(matrices)}"
        )

    Q_W, K_W, V_W = (_parse_matrix(body) for body in matrices[:3])
    return Q_W, K_W, V_W


# The header name follows the configured model width (weights32.h for DMODEL = 32).
file_path = f"weights{DMODEL}.h" 
# read_weights returns the first three matrices found in the header, in file order.
Q_W, K_W, V_W = read_weights(file_path)


## Create Attention Test Cases

In [5]:
import numpy as np

# Step-by-step walkthrough of one attention computation, printing the
# intermediate matrices. Uses DMODEL, N, Q_W, K_W, V_W, normalize and softmax
# as defined in the cells above.

# Fix the seed so that a fresh "Restart & Run All" reproduces the printed values.
np.random.seed(0)

# Step 1: Generate random tokens, uniformly distributed in [0, 15)
tokens = np.random.rand(N, DMODEL) * 15

# Weight normalization is currently disabled: the *_norm names are plain
# aliases of the raw weights (the row-norm division is left commented out).
Q_W_norm = Q_W #/ np.linalg.norm(Q_W, axis=1, keepdims=True)
K_W_norm = K_W #/ np.linalg.norm(K_W, axis=1, keepdims=True)
V_W_norm = V_W #/ np.linalg.norm(V_W, axis=1, keepdims=True)

# Step 2: Apply QKV transformations
Q = tokens @ Q_W_norm
K = tokens @ K_W_norm
V = tokens @ V_W_norm

print("Q")
print(Q)
print("K")
print(K)
print("V")
print(V)
print("--------------------")

# Rescale every row of K and Q onto [0, 16] before taking the dot products
K = normalize(K)
Q = normalize(Q)

# Step 3: Compute attention using softmax
# (the usual 1/sqrt(DMODEL) scaling is disabled)
K_T = K.T
attention_scores = Q @ K_T #/ np.sqrt(DMODEL)


print("attention scores")
print(attention_scores)

attention_weights = softmax(attention_scores)

print(attention_weights)

V = normalize(V)
attention_output = attention_weights @ V

# astype(int) truncates toward zero; it does not round
print(attention_output.astype(int))

# Step 4 (disabled): Normalize output tokens back into the range 0-15
#output_tokens = np.clip(np.round((attention_output - np.min(attention_output)) / (np.max(attention_output) - np.min(attention_output)) * 15), 0, 15).astype(int)

# Output tokens
#print(output_tokens)


Q
[[2244.42044846 1871.17444595 1626.65217674 ... 2049.65526417
  1857.73922742 1650.53865932]
 [2337.04169811 1787.69680905 1188.42695098 ... 2111.89434573
  1790.3390966  1747.37944052]
 [2475.3836782  1950.63985297 1566.43970922 ... 1897.35532207
  2039.96764183 1827.42403958]
 ...
 [2094.21984552 1748.99216755 1409.3216657  ... 1795.43562964
  1831.2698499  1592.66454053]
 [2507.40023367 2069.60273285 1500.88066401 ... 2085.687243
  1717.94926797 1937.96790285]
 [2494.18244302 2122.34128512 1671.54415902 ... 2178.31162584
  2044.40825673 1920.21657669]]
K
[[1870.46419789 1992.24068826 2209.1785537  ... 1771.10859117
  2149.96846522 1850.54380197]
 [1671.26849387 1836.7360924  2104.36467923 ... 1665.74400074
  2198.83750536 2106.18002996]
 [1797.49374106 2104.14023786 2296.865253   ... 1882.06542721
  2129.01504162 1882.41232193]
 ...
 [1705.05378325 1939.91823869 1777.59475294 ... 1602.28063355
  2083.57044168 1919.68680703]
 [1778.50243318 1856.52104674 2276.38607139 ... 1950.8190

In [6]:
# Save 10 input/output pairs to file.
# Each matrix is flattened to N*DMODEL values and written one integer per line.

np.random.seed(0)

# The weights are the same for every test case, so rescale them once up front.
Q_W_scaled = normalize(Q_W.T)
K_W_scaled = normalize(K_W.T)
V_W_scaled = normalize(V_W.T)

# NOTE(review): the file names hard-code "32" although DMODEL is configurable;
# confirm whether they should follow DMODEL like the weights header does.
input_path = "generate_tests_input32.txt"
output_path = "generate_tests_output32.txt"

# The context manager closes both files even if an iteration raises.
with open(input_path, "w") as input_file, open(output_path, "w") as output_file:
    for i in range(10):
        tokens = np.random.rand(N, DMODEL) * 15

        # NOTE(review): the inputs are written truncated to integers, while the
        # expected outputs are computed from the untruncated values; confirm
        # that the consumer of these files expects this.
        tokens_scaled = normalize(tokens)
        Q = tokens_scaled @ Q_W_scaled
        K = tokens_scaled @ K_W_scaled
        V = tokens_scaled @ V_W_scaled

        K_T = K.T
        attention_scores = Q @ K_T #/ np.sqrt(DMODEL)

        attention_weights = softmax(attention_scores)

        V = normalize(V)
        attention_output = attention_weights @ V

        # print max in each row
        print("max index in each row")

        for best_index in np.argmax(attention_scores, axis=1):
            print(best_index)

        print("--------------------")
        print_matrix(V.astype(int))
        print("--------------------")

        flat_tokens = tokens.flatten()
        flat_output = attention_output.flatten()

        # int() truncates toward zero; one value per line, no separator
        # between consecutive test cases.
        input_file.writelines(f"{int(value)}\n" for value in flat_tokens)
        output_file.writelines(f"{int(value)}\n" for value in flat_output)

max index in each row
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
--------------------
10 0 12 13 12 13 12 10 12 8 8 11 13 13 14 13 7 6 12 3 13 9 13 16 13 9 14 10 9 11 13 12 
8 0 11 11 11 10 11 5 13 11 5 8 6 13 11 11 3 3 8 2 10 12 12 12 11 12 16 9 3 9 11 5 
11 0 8 11 9 14 10 5 14 10 8 10 9 16 11 14 5 7 10 2 15 11 13 15 11 11 13 9 8 8 9 10 
11 0 9 9 10 11 11 6 9 7 6 9 11 14 13 12 6 5 11 5 12 9 11 13 13 9 16 10 7 11 10 10 
14 0 11 11 12 13 12 10 9 8 10 12 10 15 14 14 6 10 12 4 16 9 15 15 14 9 15 10 8 10 13 11 
9 0 4 9 4 8 5 10 3 5 0 7 6 12 9 10 4 0 12 1 16 8 12 13 11 10 15 6 9 10 4 8 
12 0 5 9 6 13 11 9 7 9 4 9 7 12 13 10 7 5 9 4 11 6 12 14 12 8 16 11 8 8 10 10 
12 0 9 12 10 13 12 7 9 12 9 11 9 14 13 9 8 9 11 3 12 6 15 11 15 12 16 9 9 15 9 9 
12 0 9 10 12 12 10 11 11 11 7 11 10 14 14 13 8 6 11 5 15 9 14 16 13 10 11 10 8 12 9 11 
12 0 11 10 12 13 15 7 13 12 7 10 6 16 14 12 3 9 10 1 14 12 15 14 12 13 14 13 6 11 13 8 
12 0 8 10 7 11 13 7 8 10 8 7 10 15 12 12 5 1 10 4 14 