Speed up `fast_model` by implementing single-headed causal self-attention as a Triton kernel.

In [3]:
import os

# Limit the fraction of GPU memory that JAX/XLA preallocates to 60%.
# This is set in a cell that runs before `jax` is imported below.
# NOTE(review): presumably this leaves room for torch/Triton allocations — confirm.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"

In [4]:
import jax.numpy as jnp
import jax

import torch

import triton
import triton.language as tl

import jax_triton as jt

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from nimblegpt import get_config_for
from nimblegpt.fast_model import FSingleHeadCausalSelfAttention, FGPT

In [6]:
# Configuration object for the 'gpt2' preset; later cells read
# `config.n_embd`, `config.n_head` and `config.block_size` from it.
config = get_config_for('gpt2')

# Softmax Trick

In the block kernel, we accumulate the final output incrementally in pieces. When a new piece is computed, we must 'undo' and 'redo' the softmax on the current accumulation, so as to use the new maximum and normalization factor. Suppose we have two attention vector pieces $\bm{a}^{(1)}$ and $\bm{a}^{(2)}$ and two corresponding blocks of the value matrix $\bm{V}_1$ and $\bm{V}_2$. Let $\bm{v}^{(1)}$ and $\bm{v}^{(2)}$ denote the first columns of $\bm{V}_1$ and $\bm{V}_2$ respectively. By focusing on the first column, we examine a single element of the output vector - but the result is generalizable to the entire output vector.

Define $m^{(1)}$ and $m^{(2)}$ to be the maxima of $\bm{a}^{(1)}$ and $\bm{a}^{(2)}$ respectively, and similarly $\ell^{(1)}$ and $\ell^{(2)}$ to be the corresponding softmax normalization factors, i.e. $\ell^{(k)} = \sum_i e^{a^{(k)}_i - m^{(k)}}$.

Suppose we have already computed the softmax-dot-V for each block:

$$
\begin{align}
    y_1 &= \frac{e^{\bm{a}^{(1)} - m^{(1)}}}{\ell^{(1)}} \cdot \bm{v}^{(1)}  = \frac{1}{\ell^{(1)}} \sum e^{a^{(1)}_i - m^{(1)}} \cdot v^{(1)}_i\\
    y_2 &= \frac{e^{\bm{a}^{(2)} - m^{(2)}}}{\ell^{(2)}} \cdot \bm{v}^{(2)} = \frac{1}{\ell^{(2)}} \sum e^{a^{(2)}_i - m^{(2)}} \cdot v^{(2)}_i
\end{align}
$$

Let $m$ and $\ell$ denote our new maximum and normalization factor:

$$
\begin{align}
    m &= \max(m^{(1)}, m^{(2)})\\
    \ell &= e^{m^{(1)} - m} \ell^{(1)} + e^{m^{(2)} - m} \ell^{(2)}
\end{align}
$$

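Note that this $\ell$ is exactly the softmax normalization factor of the concatenated vector (both pieces shifted by the shared maximum $m$), since: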

$$
\begin{align}
    e^{m^{(1)} - m} \ell^{(1)} + e^{m^{(2)} - m} \ell^{(2)} &= e^{m^{(1)} -m} \sum e^{a^{(1)}_i - m^{(1)}} + e^{m^{(2)} - m} \sum e^{a^{(2)}_i - m^{(2)}} \\
    &= \sum e^{a^{(1)}_i - m} + \sum e^{a^{(2)}_i - m} \\
    &= \ell
\end{align}
$$

We wish to combine $y_1$ and $y_2$ into an accumulated output, with the new maximum and normalization factor. We know that:

$$
\begin{align}
    y &= \frac{e^{\bm{a}^{(1)} - m}}{\ell} \cdot \bm{v}^{(1)} + \frac{e^{\bm{a}^{(2)} - m}}{\ell} \cdot \bm{v}^{(2)} \\
    &= \frac{\ell^{(1)}}{\ell^{(1)}} \frac{e^{m - m^{(1)}}}{e^{m - m^{(1)}}} \frac{e^{\bm{a}^{(1)} - m}}{\ell} \cdot \bm{v}^{(1)} + \frac{\ell^{(2)}}{\ell^{(2)}} \frac{e^{m - m^{(2)}}}{e^{m - m^{(2)}}} \frac{e^{\bm{a}^{(2)} - m}}{\ell} \cdot \bm{v}^{(2)} \\
    &= \frac{\ell^{(1)}}{\ell \cdot e^{m - m^{(1)}}} \frac{e^{\bm{a}^{(1)} - m^{(1)}}}{\ell^{(1)}} \cdot \bm{v}^{(1)} + \frac{\ell^{(2)}}{\ell \cdot e^{m - m^{(2)}}} \frac{e^{\bm{a}^{(2)} - m^{(2)}}}{\ell^{(2)}} \cdot \bm{v}^{(2)} \\
    &= \frac{\ell^{(1)}}{\ell \cdot e^{m - m^{(1)}}} y_1 + \frac{\ell^{(2)}}{\ell \cdot e^{m - m^{(2)}}} y_2
\end{align}
$$

# Block Kernel

In [78]:
@triton.jit
def resoftmax(y1, y2, m1, m2, l1, l2):
    """
    Merge two softmax-weighted partial outputs into one.

    `y1` and `y2` are partial results that were each normalized with their own
    maximum (`m1`, `m2`) and their own softmax denominator (`l1`, `l2`). Both are
    rescaled to a shared maximum and a shared denominator and then summed.

    Returns
    -------
    (y, m, l): the merged output, the shared maximum and the shared denominator.
    """
    # Shared maximum of the two pieces.
    m_new = tl.where(m1 > m2, m1, m2)
    # Shared denominator: each old denominator re-expressed relative to `m_new`.
    l_new = tl.exp(m1 - m_new) * l1 + tl.exp(m2 - m_new) * l2

    # Weights that undo each piece's own normalization and shift it to `m_new`.
    w1 = l1 / tl.exp(m_new - m1)
    w2 = l2 / tl.exp(m_new - m2)

    y_new = (w1 * y1 + w2 * y2) / l_new
    return y_new, m_new, l_new

In [107]:
@triton.jit
def shcsa_block_kernel(q_ptr, K_ptr, V_ptr, seq_idx_ptr, out_ptr,
                       SM_SCALE: tl.constexpr, SEQ_LEN: tl.constexpr,
                       N_FEAT: tl.constexpr, SUBSEQ_SIZE: tl.constexpr,
                       SUBFEAT_SIZE: tl.constexpr):
    """
    Triton kernel implementing single-headed attention with causal masking for a single
    token embedding. The kernel walks the sequence in pieces of `SUBSEQ_SIZE` tokens,
    flash-attention style. Each program instance computes `SUBFEAT_SIZE` elements of the
    output vector.

    For clarity, we call subsets of the sequence axis 'subseqs' and subsets of the feature
    axis 'subfeats'. We call a tensor a 'block' when it has shape [SUBSEQ_SIZE,] or
    [SUBSEQ_SIZE, N_FEAT], and a 'chunk' when it has shape [SUBFEAT_SIZE,] or
    [SUBSEQ_SIZE, SUBFEAT_SIZE].

    Program i computes out[i * SUBFEAT_SIZE: (i + 1) * SUBFEAT_SIZE] by multiplying `att`
    with V[:, i * SUBFEAT_SIZE: (i + 1) * SUBFEAT_SIZE] and summing over the sequence
    dimension.

    As with flash-attention, the output is computed iteratively in blocks. Only the
    `cdiv(seq_idx + 1, SUBSEQ_SIZE)` blocks that contain at least one unmasked token are
    visited; the running result is re-normalized after every block with `resoftmax`.

    K and V are indexed as row-major matrices with row stride `N_FEAT`.
    `N_FEAT`, `SUBSEQ_SIZE` and `SUBFEAT_SIZE` are used as `tl.arange` extents
    (Triton requires those to be powers of two).

    Inputs
    ------
    q_ptr: [N_FEAT] - query vector for the current token.
    K_ptr: [SEQ_LEN, N_FEAT] - key matrix for the entire sequence.
    V_ptr: [SEQ_LEN, N_FEAT] - value matrix for the entire sequence.
    seq_idx_ptr: scalar - index of the current token; tokens after it are masked out.

    Output
    ------
    out_ptr: [N_FEAT] - self attention output (`att @ v`)
    """
    seq_idx = tl.load(seq_idx_ptr)

    # This program computes out[subfeat_start: subfeat_start + SUBFEAT_SIZE].
    subfeat_start = tl.program_id(0) * SUBFEAT_SIZE

    feat_idxs = tl.arange(0, N_FEAT)
    subfeat_idxs = tl.arange(0, SUBFEAT_SIZE) + subfeat_start

    q = tl.load(q_ptr + feat_idxs)  # [N_FEAT,]

    y_chunk_acc = tl.zeros((SUBFEAT_SIZE, ), dtype=tl.float32)
    ## Running softmax - flash-attention style.
    # Running attention maximum.
    m_acc = float("-inf")
    # Running softmax denominator.
    l_acc = 0.0

    # Don't bother calculating attention and outputs for tokens which are masked.
    n_subseq = tl.cdiv(seq_idx + 1, SUBSEQ_SIZE)

    for subseq_i in range(0, n_subseq):
        # Each iteration, we load a [SUBSEQ_SIZE, N_FEAT] block of `K` and compute a
        # [SUBSEQ_SIZE,] block of `att`. We multiply this by a [SUBSEQ_SIZE, SUBFEAT_SIZE]
        # chunk of `V` to compute a [SUBFEAT_SIZE,] partial-result chunk of `out`.

        # Indices of the sequence tokens processed in this block. Block `subseq_i`
        # starts at token `subseq_i * SUBSEQ_SIZE`; offsetting by `subseq_i` alone
        # would make consecutive blocks overlap and never reach later tokens.
        subseq_idxs = tl.arange(0, SUBSEQ_SIZE) + subseq_i * SUBSEQ_SIZE
        # Causal mask for sequence tokens in this block.
        subseq_mask = subseq_idxs <= seq_idx

        # Index and mask into K. Sizes [SUBSEQ_SIZE, N_FEAT].
        block_idxs = subseq_idxs[:, None] * N_FEAT + feat_idxs[None, :]
        block_mask = tl.broadcast_to(subseq_mask[:, None],
                                     (SUBSEQ_SIZE, N_FEAT))

        K_block = tl.load(K_ptr + block_idxs,
                          mask=block_mask, other=0.0)  # [SUBSEQ_SIZE, N_FEAT]

        att_block = tl.sum(q[None, :] * K_block,
                           axis=1) * SM_SCALE  # [SUBSEQ_SIZE,]
        catt_block = tl.where(subseq_mask, att_block,
                              float("-inf"))  # [SUBSEQ_SIZE,]

        # The first token of every visited block is unmasked, so this is finite.
        max_block = tl.max(catt_block, axis=0)
        # Softmax numerator of this block.
        sm_num_block = tl.exp(catt_block - max_block)
        # Softmax denominator of this block.
        sm_den_block = tl.sum(sm_num_block, axis=0)
        sm_att_block = sm_num_block / sm_den_block

        # Index and mask into V. Sizes [SUBSEQ_SIZE, SUBFEAT_SIZE].
        chunk_idxs = subseq_idxs[:, None] * N_FEAT + subfeat_idxs[None, :]
        chunk_mask = tl.broadcast_to(subseq_mask[:, None],
                                     (SUBSEQ_SIZE, SUBFEAT_SIZE))

        # `other=0.0` gives masked lanes a defined value; otherwise their zero
        # attention weight could be multiplied with garbage (possibly NaN/inf).
        V_chunk = tl.load(V_ptr + chunk_idxs,
                          mask=chunk_mask,
                          other=0.0)  # [SUBSEQ_SIZE, SUBFEAT_SIZE]

        # Partial result of a chunk of the output.
        # ([SUBSEQ_SIZE,] -> [SUBSEQ_SIZE, 1]) * [SUBSEQ_SIZE, SUBFEAT_SIZE]
        # -> [SUBSEQ_SIZE, SUBFEAT_SIZE] -{sum}-> [SUBFEAT_SIZE,]
        out_chunk_pr = tl.sum(sm_att_block[:, None] * V_chunk, axis=0)

        # Fold this block into the running output, maximum and denominator.
        y_chunk_acc, m_acc, l_acc = resoftmax(y_chunk_acc, out_chunk_pr, m_acc,
                                              max_block, l_acc, sm_den_block)

    tl.store(out_ptr + subfeat_idxs, y_chunk_acc)


In [119]:
def shcsa_block(q,
                K,
                V,
                seq_idx,
                SUBSEQ_SIZE: int = 128,
                SUBFEAT_SIZE: int = 32):
    """
    Launch `shcsa_block_kernel` for a single query token via `jax_triton`.

    Parameters
    ----------
    q: [N_FEAT] query vector.
    K: [SEQ_LEN, N_FEAT] key matrix.
    V: [SEQ_LEN, N_FEAT] value matrix.
    seq_idx: scalar array with the index of the current token.
    SUBSEQ_SIZE: number of sequence tokens processed per loop iteration in the kernel.
    SUBFEAT_SIZE: number of output features computed by each program instance.

    Returns
    -------
    Array of shape [N_FEAT] and dtype `q.dtype` holding the attention output.

    Raises
    ------
    ValueError: if a block size is not positive, or if `N_FEAT` is not a multiple of
        `SUBFEAT_SIZE` (the grid would not cover the whole output vector).
    """
    N_FEAT = q.shape[0]

    if SUBSEQ_SIZE <= 0 or SUBFEAT_SIZE <= 0:
        raise ValueError(
            f"SUBSEQ_SIZE and SUBFEAT_SIZE must be positive, got "
            f"{SUBSEQ_SIZE} and {SUBFEAT_SIZE}")
    # One program per SUBFEAT_SIZE features: a remainder would be left unwritten.
    if N_FEAT % SUBFEAT_SIZE != 0:
        raise ValueError(
            f"N_FEAT ({N_FEAT}) must be a multiple of SUBFEAT_SIZE ({SUBFEAT_SIZE})")

    out_shape = jax.ShapeDtypeStruct((N_FEAT, ), q.dtype)
    grid = (N_FEAT // SUBFEAT_SIZE, )

    return jt.triton_call(q,
                          K,
                          V,
                          seq_idx,
                          kernel=shcsa_block_kernel,
                          out_shape=out_shape,
                          grid=grid,
                          SM_SCALE=1.0 / N_FEAT**0.5,
                          SEQ_LEN=K.shape[0],
                          N_FEAT=N_FEAT,
                          SUBSEQ_SIZE=SUBSEQ_SIZE,
                          SUBFEAT_SIZE=SUBFEAT_SIZE)


In [120]:
class SHCSABlock(FSingleHeadCausalSelfAttention):
    """
    Single-head causal self-attention whose attention step runs in the blocked
    Triton kernel (`shcsa_block`).
    """
    # Number of sequence tokens the kernel processes per loop iteration.
    subseq_size: int = 128
    # Number of output features computed by each kernel program instance.
    subfeat_size: int = 32

    def __call__(self, x: jax.Array, seq_idx: jax.Array):
        """Return the attention output for input `x` at position `seq_idx`."""
        q, K, V = self.get_qKV(x, seq_idx)
        # Forward the configured block sizes; previously the fields were
        # declared but the kernel always ran with its own defaults.
        return shcsa_block(q,
                           K,
                           V,
                           seq_idx,
                           SUBSEQ_SIZE=self.subseq_size,
                           SUBFEAT_SIZE=self.subfeat_size)

In [121]:
# Fixed seed so the comparison below is reproducible.
rng = jax.random.PRNGKey(0)
# Feature width of a single attention head.
n_feat = config.n_embd // config.n_head
# Context length passed to the attention modules as `n_cntx`.
n_cntx = config.block_size
# Random input vector for a single token.
x = jax.random.normal(rng, (n_feat,))
x[:5]

Array([ 1.2799144 , -0.39865986, -0.5993886 , -0.7637496 , -0.8983587 ],      dtype=float32)

In [117]:
# Baseline module from `nimblegpt.fast_model`; its output is compared with the
# Triton-backed module below.
fshcsa_module = FSingleHeadCausalSelfAttention(n_cntx=n_cntx, n_feat = n_feat)
# NOTE(review): `vars` shadows the Python builtin of the same name.
vars = fshcsa_module.init(rng, x, jnp.array(0))

In [81]:
# Baseline output at sequence position 0. With `mutable="cache"` the call returns
# the output together with the updated "cache" collection.
fy, fvars = fshcsa_module.apply(vars, x, jnp.array(0), mutable="cache")
fy

Array([-0.76740426,  0.45793283,  0.7712825 , -0.71566945, -1.3524578 ,
       -0.02418369,  0.33236438, -0.27028525,  0.38507232,  0.7667557 ,
        0.2532947 , -0.21232384, -1.6268221 , -0.52965856,  0.79835236,
        1.1416652 , -0.0675956 , -0.43777275, -0.4003198 ,  1.12303   ,
       -0.05400813,  0.42411083, -1.8518133 ,  1.2761084 , -1.313626  ,
        0.08351888, -2.1435814 ,  1.9459411 , -0.9885833 ,  0.07802778,
        0.9368881 ,  1.2056795 , -0.74474347,  0.74293506,  0.71277654,
       -0.7897532 ,  0.46426433, -0.13851726, -1.3792179 ,  0.72289133,
        0.39826477, -0.21821803,  0.9840089 , -0.5381244 ,  0.7619692 ,
        0.31365275, -0.591218  ,  0.48691204,  0.8179485 , -0.8554354 ,
        1.0123616 , -0.7641751 ,  0.3230685 ,  0.06167361,  1.16292   ,
       -0.36336797, -0.5731621 ,  0.78331184,  0.821721  , -0.5655025 ,
       -1.2870891 ,  0.7006638 ,  0.5295207 , -0.087327  ], dtype=float32)

In [122]:
# Same variables and input, but attention is computed by the blocked Triton kernel.
# NOTE(review): with seq_idx = 0 the kernel visits only one sequence block, so the
# multi-block accumulation path is not exercised by this check.
module = SHCSABlock(n_cntx=n_cntx, n_feat = n_feat)
y, _ = module.apply(vars, x, jnp.array(0), mutable="cache")
y

Array([-0.76740426,  0.45793283,  0.7712825 , -0.71566945, -1.3524578 ,
       -0.02418369,  0.33236438, -0.27028525,  0.38507232,  0.7667557 ,
        0.2532947 , -0.21232384, -1.6268221 , -0.52965856,  0.79835236,
        1.1416652 , -0.0675956 , -0.43777275, -0.4003198 ,  1.12303   ,
       -0.05400813,  0.42411083, -1.8518133 ,  1.2761084 , -1.313626  ,
        0.08351888, -2.1435814 ,  1.9459411 , -0.9885833 ,  0.07802778,
        0.9368881 ,  1.2056795 , -0.74474347,  0.74293506,  0.71277654,
       -0.7897532 ,  0.46426433, -0.13851726, -1.3792179 ,  0.72289133,
        0.39826477, -0.21821803,  0.9840089 , -0.5381244 ,  0.7619692 ,
        0.31365275, -0.591218  ,  0.48691204,  0.8179485 , -0.8554354 ,
        1.0123616 , -0.7641751 ,  0.3230685 ,  0.06167361,  1.16292   ,
       -0.36336797, -0.5731621 ,  0.78331184,  0.821721  , -0.5655025 ,
       -1.2870891 ,  0.7006638 ,  0.5295207 , -0.087327  ], dtype=float32)

In [123]:
# Largest absolute deviation between the baseline and the kernel output.
# A signed maximum would hide errors where the kernel output is too large.
jnp.abs(fy - y).max()

Array(0., dtype=float32)

# Row Kernel

In [94]:
@triton.jit
def shcsa_row_kernel(
    q_ptr,
    K_ptr,
    V_ptr,
    seq_idx_ptr,
    out_ptr,
    SM_SCALE: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    N_FEAT: tl.constexpr,
):
    """
    Triton kernel for single-headed causal self-attention of one query token.
    The full attention row over all `SEQ_LEN` positions is computed in a single pass.

    K and V are indexed as row-major matrices with row stride `N_FEAT`.

    Inputs
    ------
    q_ptr: [N_FEAT] - query vector for the current token.
    K_ptr: [SEQ_LEN, N_FEAT] - key matrix for the entire sequence.
    V_ptr: [SEQ_LEN, N_FEAT] - value matrix for the entire sequence.
    seq_idx_ptr: scalar - index of the current token; later positions are masked.

    Output
    ------
    out_ptr: [N_FEAT] - self attention output (`att @ v`)
    """
    cur_idx = tl.load(seq_idx_ptr)

    rows = tl.arange(0, SEQ_LEN)
    cols = tl.arange(0, N_FEAT)
    # Flat offsets of every (row, col) element. Shape [SEQ_LEN, N_FEAT].
    offsets = rows[:, None] * N_FEAT + cols[None, :]

    # [SEQ_LEN,] - true for positions the current token may attend to.
    row_visible = rows <= cur_idx
    # Same mask, repeated along the feature axis. Shape [SEQ_LEN, N_FEAT].
    mat_visible = tl.broadcast_to(row_visible[:, None], (SEQ_LEN, N_FEAT))

    q = tl.load(q_ptr + cols)
    # Masked rows are read as zeros.
    K = tl.load(K_ptr + offsets, mask=mat_visible, other=0.0)
    V = tl.load(V_ptr + offsets, mask=mat_visible, other=0.0)

    # Scaled dot product of q with every key row. Shape [SEQ_LEN,].
    scores = tl.sum(q[None, :] * K, axis=1) * SM_SCALE
    # Masked positions get -inf so they receive zero softmax weight.
    scores = tl.where(row_visible, scores, float("-inf"))

    # Numerically stable softmax over the sequence axis.
    peak = tl.max(scores, axis=0)
    weights = tl.exp(scores - peak)
    norm = tl.sum(weights, axis=0)
    probs = weights / norm  # [SEQ_LEN,]

    # Attention-weighted sum of the value rows. Shape [N_FEAT,].
    result = tl.sum(probs[:, None] * V, axis=0)

    tl.store(out_ptr + cols, result)

In [95]:
def shcsa_row(q, K, V, seq_idx):
    """
    Launch `shcsa_row_kernel` for a single query token via `jax_triton`.

    `q` is the [N_FEAT] query vector, `K` and `V` are [SEQ_LEN, N_FEAT] matrices and
    `seq_idx` is the index of the current token. A single program computes the whole
    output, which has shape [N_FEAT] and dtype `q.dtype`.
    """
    n_feat = q.shape[0]
    seq_len = K.shape[0]

    return jt.triton_call(
        q,
        K,
        V,
        seq_idx,
        kernel=shcsa_row_kernel,
        out_shape=jax.ShapeDtypeStruct((n_feat,), q.dtype),
        grid=(1,),
        SM_SCALE=1.0 / n_feat ** 0.5,
        SEQ_LEN=seq_len,
        N_FEAT=n_feat,
    )

In [96]:
class SHCSARow(FSingleHeadCausalSelfAttention):
    """
    Single-head causal self-attention whose attention step runs in the
    single-pass Triton kernel (`shcsa_row`).
    """

    def __call__(self, x: jax.Array, seq_idx: jax.Array):
        """Return the attention output for input `x` at position `seq_idx`."""
        query, keys, values = self.get_qKV(x, seq_idx)
        return shcsa_row(query, keys, values, seq_idx)

In [97]:
# Fixed seed so the comparison below is reproducible.
rng = jax.random.PRNGKey(0)
# Feature width of a single attention head.
n_feat = config.n_embd // config.n_head
# Context length passed to the attention modules as `n_cntx`.
n_cntx = config.block_size
# Random input vector for a single token.
x = jax.random.normal(rng, (n_feat,))
x[:5]

Array([ 1.2799144 , -0.39865986, -0.5993886 , -0.7637496 , -0.8983587 ],      dtype=float32)

In [98]:
# Baseline module from `nimblegpt.fast_model`; its output is compared with the
# Triton-backed module below.
fshcsa_module = FSingleHeadCausalSelfAttention(n_cntx=n_cntx, n_feat = n_feat)
# NOTE(review): `vars` shadows the Python builtin of the same name.
vars = fshcsa_module.init(rng, x, jnp.array(0))

In [99]:
# Baseline output at sequence position 0. With `mutable="cache"` the call returns
# the output together with the updated "cache" collection.
fy, fvars = fshcsa_module.apply(vars, x, jnp.array(0), mutable="cache")
fy

Array([-0.76740426,  0.45793283,  0.7712825 , -0.71566945, -1.3524578 ,
       -0.02418369,  0.33236438, -0.27028525,  0.38507232,  0.7667557 ,
        0.2532947 , -0.21232384, -1.6268221 , -0.52965856,  0.79835236,
        1.1416652 , -0.0675956 , -0.43777275, -0.4003198 ,  1.12303   ,
       -0.05400813,  0.42411083, -1.8518133 ,  1.2761084 , -1.313626  ,
        0.08351888, -2.1435814 ,  1.9459411 , -0.9885833 ,  0.07802778,
        0.9368881 ,  1.2056795 , -0.74474347,  0.74293506,  0.71277654,
       -0.7897532 ,  0.46426433, -0.13851726, -1.3792179 ,  0.72289133,
        0.39826477, -0.21821803,  0.9840089 , -0.5381244 ,  0.7619692 ,
        0.31365275, -0.591218  ,  0.48691204,  0.8179485 , -0.8554354 ,
        1.0123616 , -0.7641751 ,  0.3230685 ,  0.06167361,  1.16292   ,
       -0.36336797, -0.5731621 ,  0.78331184,  0.821721  , -0.5655025 ,
       -1.2870891 ,  0.7006638 ,  0.5295207 , -0.087327  ], dtype=float32)

In [100]:
# Same variables and input, but attention is computed by the single-pass Triton kernel.
module = SHCSARow(n_cntx=n_cntx, n_feat = n_feat)
y, _ = module.apply(vars, x, jnp.array(0), mutable="cache")

In [101]:
# Repeats the call from the previous cell with identical arguments.
# NOTE(review): presumably a re-run after the first (compiling) call — confirm or remove.
y, _ = module.apply(vars, x, jnp.array(0), mutable="cache")

In [102]:
y

Array([-0.76740426,  0.45793283,  0.7712825 , -0.71566945, -1.3524578 ,
       -0.02418369,  0.33236438, -0.27028525,  0.38507232,  0.7667557 ,
        0.2532947 , -0.21232384, -1.6268221 , -0.52965856,  0.79835236,
        1.1416652 , -0.0675956 , -0.43777275, -0.4003198 ,  1.12303   ,
       -0.05400813,  0.42411083, -1.8518133 ,  1.2761084 , -1.313626  ,
        0.08351888, -2.1435814 ,  1.9459411 , -0.9885833 ,  0.07802778,
        0.9368881 ,  1.2056795 , -0.74474347,  0.74293506,  0.71277654,
       -0.7897532 ,  0.46426433, -0.13851726, -1.3792179 ,  0.72289133,
        0.39826477, -0.21821803,  0.9840089 , -0.5381244 ,  0.7619692 ,
        0.31365275, -0.591218  ,  0.48691204,  0.8179485 , -0.8554354 ,
        1.0123616 , -0.7641751 ,  0.3230685 ,  0.06167361,  1.16292   ,
       -0.36336797, -0.5731621 ,  0.78331184,  0.821721  , -0.5655025 ,
       -1.2870891 ,  0.7006638 ,  0.5295207 , -0.087327  ], dtype=float32)

In [29]:
from functools import partial


# Build an FGPT model, handing it a factory for the Triton row-kernel attention
# module with `n_cntx` and `n_feat` already bound.
# NOTE(review): presumably FGPT instantiates one such module per head — confirm.
kgpt = FGPT.MakeWithSHCSA(config, partial(SHCSARow, n_cntx=n_cntx, n_feat = n_feat))
# NOTE(review): rebinds `vars`, which also shadows the Python builtin.
vars = kgpt.init()

In [30]:
kgpt.generate?

[0;31mSignature:[0m
[0mkgpt[0m[0;34m.[0m[0mgenerate[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mrng[0m[0;34m:[0m [0;34m<[0m[0mfunction[0m [0mPRNGKey[0m [0mat[0m [0;36m0x7f7406a63ac0[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mvariables[0m[0;34m:[0m [0mDict[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mprompt_idxs[0m[0;34m:[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlogit_sampler[0m[0;34m:[0m [0mCallable[0m[0;34m[[0m[0;34m[[0m[0mPRNGKey[0m[0;34m,[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m][0m[0;34m,[0m [0mjax[0m[0;34m.[0m[0mArray[0m[0;34m][0m [0;34m=[0m [0;34m<[0m[0mfunction[0m [0mFGPT[0m[0;34m.[0m[0;34m<[0m[0;32mlambda[0m[0;34m>[0m [0mat[0m [0;36m0x7f740acd60e0[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_new_tokens[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m10[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m


In [13]:
# -inf * 0 evaluates to NaN (see the output below).
# NOTE(review): presumably a scratch check on how masked (-inf) scores interact with zeros.
float("-inf") * 0

nan

In [14]:
# Call the module's `get_qKV` method directly to inspect what the kernels receive.
fshcsa_module.apply(vars, x, jnp.array(0), mutable="cache", method="get_qKV")

((Array([-6.51827455e-02, -1.64756298e+00,  8.68667841e-01,  1.08457220e+00,
         -7.01299548e-01, -5.72350860e-01, -3.95007074e-01,  7.66254127e-01,
         -7.61115551e-01,  1.54181850e+00,  1.33208990e-01, -1.61614037e+00,
          1.48971975e-02,  6.72358334e-01, -1.21386719e+00, -5.99314332e-01,
         -3.89698446e-01,  9.63883162e-01,  2.28442162e-01,  9.72948968e-02,
         -1.19103360e+00,  6.08899295e-02,  9.51677740e-01,  9.33376431e-01,
          1.76112592e-01, -4.33435768e-01,  3.72164249e-02,  3.05139422e-02,
         -4.05094266e-01,  1.25109792e+00, -1.74354315e-02,  7.83537507e-01,
         -3.58243942e-01,  1.08284950e-01, -7.79186428e-01,  9.94483590e-01,
         -6.30619824e-01, -1.08396506e+00,  4.82888997e-01, -9.37316477e-01,
         -9.34945345e-02, -3.83898556e-01,  5.94781935e-01,  1.69102609e-01,
         -1.19115740e-01, -5.89539051e-01, -1.92776787e+00, -1.27550960e-03,
          1.99623942e+00, -3.00156474e-01,  6.18577838e-01, -1.29633510e+00,

In [15]:
# Attention weights with all mass on the first position: a 1 followed by
# `n_cntx - 1` zeros.
sm_att = torch.concatenate((torch.tensor([1]), torch.zeros(n_cntx - 1)))
sm_att

tensor([1., 0., 0.,  ..., 0., 0., 0.])

In [16]:
# Random value matrix of shape (n_cntx, n_feat) for the torch sanity check.
V = torch.randn((n_cntx, n_feat))
V.shape, V

(torch.Size([1024, 64]),
 tensor([[ 1.7014, -0.1726, -1.6793,  ..., -0.7675, -0.1582, -0.8307],
         [-0.3748, -0.1536,  0.6626,  ..., -1.0473, -0.4401,  1.9859],
         [-0.9378, -1.0736, -1.6285,  ...,  0.8539, -0.6150, -1.1558],
         ...,
         [ 1.0742, -0.6377, -0.2379,  ..., -1.6558, -0.9871,  0.6652],
         [-1.0223, -0.3294,  0.1540,  ...,  0.0216,  0.1360,  0.0924],
         [ 1.4035,  1.7089, -2.7962,  ..., -1.1578, -0.3323,  0.0676]]))

In [17]:
# Scale each row of V by its attention weight; only the first row stays non-zero.
sm_att[:, None] * V

tensor([[ 1.7014, -0.1726, -1.6793,  ..., -0.7675, -0.1582, -0.8307],
        [-0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0000, -0.0000],
        ...,
        [ 0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0000, -0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000]])

In [18]:
# Sum over the sequence axis; with the weights above this reproduces the first row of V.
torch.sum(sm_att[:, None] * V, axis=0)

tensor([ 1.7014, -0.1726, -1.6793, -2.0957,  0.7844, -1.2923, -1.1041,  1.0879,
         0.3221,  0.9796,  0.2768,  0.6901,  0.0898, -0.4092, -0.2104, -0.5178,
        -0.8111,  0.7136,  1.5241, -1.2647, -0.2747, -0.1076,  1.0560,  1.4402,
        -1.5324, -0.1405,  0.6247,  1.4614,  0.4163,  1.4627, -0.7094,  1.5771,
        -1.3416, -0.2051,  0.3650,  0.4769,  0.3391, -0.2098, -0.6240, -0.1273,
         0.4322, -0.7927, -0.2393, -1.2307,  0.9321, -1.0986,  0.7401,  0.9525,
         1.1932,  0.2744, -0.7200, -2.6377, -1.2093,  1.2528,  0.4161, -1.4054,
        -0.4533, -0.7188, -0.4449,  0.8719,  0.9343, -0.7675, -0.1582, -0.8307])