# Simple scaled dot-product attention

## Imports and target

In [None]:
import math
import numpy as np
import tvm
from tvm import te, tir
import tvm.testing

# Compilation target constructed from the "llvm" target string.
target = tvm.target.Target("llvm")
# Runtime device of the target's kind at index 0; the tensors created in the
# later cells are placed on this device.
dev = tvm.device(target.kind.name, 0)

## Tensor expression for attention

In [None]:
# Problem size and element type shared by every cell below.
SEQ_LEN = 1024
HEAD_DIM = 1024
DTYPE = "float32"
# Attention scaling factor 1 / sqrt(HEAD_DIM), applied to the raw Q.K^T scores.
SCALE = 1.0 / math.sqrt(HEAD_DIM)

# Inputs: Q, K and V are each a (SEQ_LEN, HEAD_DIM) matrix of DTYPE.
Q = te.placeholder((SEQ_LEN, HEAD_DIM), name="Q", dtype=DTYPE)
K = te.placeholder((SEQ_LEN, HEAD_DIM), name="K", dtype=DTYPE)
V = te.placeholder((SEQ_LEN, HEAD_DIM), name="V", dtype=DTYPE)

# scores_unscaled[i, j] = sum_k Q[i, k] * K[j, k], i.e. Q @ K^T.
k = te.reduce_axis((0, HEAD_DIM), name="k")
scores_unscaled = te.compute(
    (SEQ_LEN, SEQ_LEN),
    lambda i, j: te.sum(Q[i, k] * K[j, k], axis=k),
    name="scores_unscaled",
)
scores = te.compute(
    scores_unscaled.shape,
    lambda i, j: scores_unscaled[i, j] * SCALE,
    name="scores",
)

# Row-wise softmax over `scores`. The row maximum is subtracted before
# exponentiating: the result is mathematically unchanged, but exp() can no
# longer overflow when a score is large.
j_max = te.reduce_axis((0, SEQ_LEN), name="j_max")
row_max = te.compute(
    (SEQ_LEN,),
    lambda i: te.max(scores[i, j_max], axis=j_max),
    name="row_max",
)

score_exp = te.compute(
    scores.shape,
    lambda i, j: tir.exp(scores[i, j] - row_max[i]),
    name="score_exp",
)

j_softmax = te.reduce_axis((0, SEQ_LEN), name="j_softmax")
row_sum = te.compute(
    (SEQ_LEN,),
    lambda i: te.sum(score_exp[i, j_softmax], axis=j_softmax),
    name="row_sum",
)

prob = te.compute(
    scores.shape,
    lambda i, j: score_exp[i, j] / row_sum[i],
    name="prob",
)

# attention[i, d] = sum_j prob[i, j] * V[j, d], i.e. prob @ V.
# The output column index is called `d` so it does not shadow the reduce
# axis `k` defined above.
jv = te.reduce_axis((0, SEQ_LEN), name="jv")
attention = te.compute(
    (SEQ_LEN, HEAD_DIM),
    lambda i, d: te.sum(prob[i, jv] * V[jv, d], axis=jv),
    name="attention",
)

# Turn the compute graph into a PrimFunc with arguments (Q, K, V, attention)
# and give it the global symbol "attention".
prim_func = te.create_prim_func([Q, K, V, attention])
prim_func = prim_func.with_attr("global_symbol", "attention")

# Put the function into a module under the same name, build it for `target`,
# and fetch the callable from the built library.
mod = tvm.ir.IRModule({"attention": prim_func})
lib = tvm.build(mod, target=target)
func = lib["attention"]


## Build and validate

In [None]:
# Deterministic standard-normal inputs (fixed seed) so runs are repeatable.
rng = np.random.default_rng(0)
q_np = rng.standard_normal((SEQ_LEN, HEAD_DIM), dtype=np.float32)
k_np = rng.standard_normal((SEQ_LEN, HEAD_DIM), dtype=np.float32)
v_np = rng.standard_normal((SEQ_LEN, HEAD_DIM), dtype=np.float32)

# Device-side copies of the inputs, plus an output buffer for the kernel.
q_tvm = tvm.runtime.tensor(q_np.astype(DTYPE), dev)
k_tvm = tvm.runtime.tensor(k_np.astype(DTYPE), dev)
v_tvm = tvm.runtime.tensor(v_np.astype(DTYPE), dev)
out_tvm = tvm.runtime.tensor(np.empty((SEQ_LEN, HEAD_DIM), dtype=DTYPE), dev)

func(q_tvm, k_tvm, v_tvm, out_tvm)

# NumPy reference: softmax(Q @ K^T * SCALE) @ V. Each row's maximum is
# subtracted before exponentiating; this leaves the softmax unchanged
# mathematically but prevents exp() from overflowing on large scores.
scores_np = (q_np @ k_np.T) * SCALE
weights_np = np.exp(scores_np - scores_np.max(axis=1, keepdims=True))
weights_np /= weights_np.sum(axis=1, keepdims=True)
answer = weights_np @ v_np
tvm.testing.assert_allclose(out_tvm.numpy(), answer, rtol=1e-3, atol=1e-3)

In [None]:
def evaluate_operation(lib, optimization, log):
    """Check and time the "attention" function of a built module.

    The function is run once on the notebook-level inputs (``q_tvm``,
    ``k_tvm``, ``v_tvm``) and its output is compared with ``answer``. It is
    then timed, and the result is printed and recorded.

    Parameters
    ----------
    lib
        Built module exposing a function named "attention".
    optimization : str
        Label identifying this variant in the printed line and in ``log``.
    log : list
        Receives one ``(optimization, mean_time)`` tuple, appended in place.
    """
    func_name = "attention"
    func = lib[func_name]
    assert func, f"module has no function named {func_name!r}"

    # Start from a zeroed buffer so the comparison below cannot pass on
    # leftover data from an earlier run.
    out_tvm = tvm.runtime.tensor(np.zeros((SEQ_LEN, HEAD_DIM), dtype=DTYPE), dev)
    func(q_tvm, k_tvm, v_tvm, out_tvm)
    tvm.testing.assert_allclose(out_tvm.numpy(), answer, rtol=1e-3, atol=1e-3)

    # Time the function by the same name that was just validated, rather
    # than whichever function the module designates as its entry point.
    evaluator = lib.time_evaluator(func_name, dev, number=10)
    mean_time = evaluator(q_tvm, k_tvm, v_tvm, out_tvm).mean
    print(f"{optimization}: {mean_time:f}")
    log.append((optimization, mean_time))

In [None]:
# Collects one (label, mean time) entry per evaluated variant; the first
# entry is the unoptimized build.
log = []

evaluate_operation(lib=lib, optimization="none", log=log)

## Scheduling optimization
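TODO: add scheduled variants of the attention kernel here and record each one with `evaluate_operation` so it appears in the summary below.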

## Summary

In [None]:
# Print every logged timing next to its ratio against the first entry.
# Each cell is right-justified to 20 characters and cells are tab-separated.
baseline = log[0][1]
header = ("Operator", "Timing", "Performance")
print("\t".join(cell.rjust(20) for cell in header))
for name, elapsed in log:
    cells = (name, str(elapsed), str(elapsed / baseline))
    print("\t".join(cell.rjust(20) for cell in cells))