In [1]:
import plotly.express as px
import numpy as np

In [22]:
def draw_mask(mask):
    """Show a 2-D array as a plotly heatmap on the 'blues' colour scale.

    Used for both connectivity masks and attention outputs; nothing is returned.
    """
    px.imshow(mask, color_continuous_scale='blues').show()

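A connectivity pattern $S = \{S_1, \ldots, S_n\}$ assigns to each output position $i$ the set $S_i$ of indices of the input vectors that the $i^{\text{th}}$ output vector attends to. For example, causal full self-attention uses $S_i = \{j : j \le i\}$, so every position attends to itself and to all earlier positions.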

In [86]:
def softmax(x, axis=0):
    """Numerically stable softmax of `x` along `axis` (default 0, as before).

    The per-slice maximum is subtracted before exponentiating. The result is
    mathematically unchanged, but large logits (e.g. 1000) no longer overflow
    `np.exp` to inf, which would give nan.

    Parameters
    ----------
    x : array_like
        Logits. For a 1-D input this is the usual softmax over the vector.
    axis : int, optional
        Axis that is normalised to sum to 1.

    Returns
    -------
    ndarray
        Float array with the same shape as `x`. An empty input is returned unchanged.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        # np.max would raise on an empty array; an empty softmax is just empty.
        return x
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

def attend(X, S, Wq, Wk, Wv):
    """Sparse scaled dot-product attention under connectivity pattern `S`.

    Output row i is softmax(q_i K_{S_i}^T / sqrt(d_k)) V_{S_i}, where
    q_i = X[i] @ Wq and K = X @ Wk, V = X @ Wv restricted to the rows listed in S[i].

    Parameters
    ----------
    X : ndarray, shape (n, d_emb)
        Input vectors, one per row.
    S : sequence of length n
        S[i] is an iterable of integer row indices into X (list, range, set,
        integer ndarray, ...): the inputs that output i attends to.
    Wq, Wk : ndarray, shape (d_emb, d_k)
        Query and key projections.
    Wv : ndarray, shape (d_emb, d_v)
        Value projection.

    Returns
    -------
    ndarray, shape (n, d_v)
        Attention outputs. A row whose S[i] is empty stays all zeros.
    """
    n = X.shape[0]
    d_k = Wq.shape[1]
    d_v = Wv.shape[1]

    # Project every position once. The old loop re-projected X[S_i] for every
    # output row, which is quadratic work for dense patterns.
    Q = X @ Wq
    K = X @ Wk
    V = X @ Wv
    scale = np.sqrt(d_k)

    attended = np.zeros((n, d_v))
    for i in range(n):
        idx = np.fromiter(S[i], dtype=np.intp)
        if idx.size == 0:
            # Nothing to attend to: keep the zero row (an empty weighted sum).
            continue
        scores = (Q[i] @ K[idx].T) / scale
        attended[i] = softmax(scores) @ V[idx]
    return attended

n = 16
d_emb = 16
d_k = 16
d_v = 16

# Fixed seed so Restart & Run All reproduces the same figure.
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d_emb))
Wq = rng.standard_normal((d_emb, d_k))
Wk = rng.standard_normal((d_emb, d_k))
Wv = rng.standard_normal((d_emb, d_v))

# Causal full attention: S_i = {j : j <= i}, so output i attends to itself too.
# With range(i), S_0 would be empty and output row 0 would be all zeros.
S = [list(range(i + 1)) for i in range(n)]

# Note: this holds the attention *outputs* (n x d_v), not a connectivity mask.
connectivity = attend(X, S, Wq, Wk, Wv)
assert connectivity.shape == (n, d_v)
draw_mask(connectivity)

In [89]:
def full_self_attention(n):
    """Connectivity mask for causal full self-attention.

    Entry [i, j] is 1.0 when output i attends to input j, i.e. when j <= i
    (S_i = {0, ..., i}), and 0.0 otherwise.

    Parameters
    ----------
    n : int
        Sequence length.

    Returns
    -------
    ndarray, shape (n, n), float
        Lower-triangular 0/1 mask, diagonal included.
    """
    # Keep the diagonal: every position attends to itself. Without it, row 0
    # would attend to nothing.
    return np.tril(np.ones((n, n)))

draw_mask(full_self_attention(n=16))

In [91]:
def strided_sparse_attention(n, l):
    """Connectivity mask for the strided sparse-attention pattern.

    Output i attends to the union of two causal sets:
      * a sliding window, {j : max(0, i - l) <= j <= i}, i.e. 0 <= i - j <= l;
      * a stride, {j : j <= i and (i - j) mod l == 0}.

    Parameters
    ----------
    n : int
        Sequence length.
    l : int
        Window length / stride; must be >= 1.

    Returns
    -------
    ndarray, shape (n, n), float
        0/1 mask. Entry [i, j] is 1.0 when j is in S_i. Cells in both sets are
        1.0, not 2.0, because S_i is a set union.

    Raises
    ------
    ValueError
        If l < 1.
    """
    if l < 1:
        raise ValueError(f"l must be a positive integer, got {l!r}")

    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]  # diff[i, j] = i - j
    # Restricting both sets to j <= i up front removes the need for a tril
    # afterwards: the stride can no longer reach future positions.
    causal = diff >= 0
    sliding = causal & (diff <= l)
    strided = causal & (diff % l == 0)
    return (sliding | strided).astype(float)

n, l = 16, 4

connectivity_matrix = strided_sparse_attention(n=n, l=l)
draw_mask(connectivity_matrix)

In [101]:
def fixed_sparse_attention(n, l, c):
    """Connectivity mask for the fixed sparse-attention pattern.

    Output i attends causally (j <= i) to the union of:
      * its own block, {j : floor(j / l) == floor(i / l)};
      * the "summary" positions, {j : j mod l in {l - c, ..., l - 1}}, i.e. the
        last c positions of every length-l block.

    Parameters
    ----------
    n : int
        Sequence length.
    l : int
        Block length; must be >= 1.
    c : int
        Number of summary positions per block. c >= l makes every position a
        summary position; c <= 0 leaves only the block term.

    Returns
    -------
    ndarray, shape (n, n), float
        0/1 mask. Entry [i, j] is 1.0 when j is in S_i.

    Raises
    ------
    ValueError
        If l < 1.
    """
    if l < 1:
        raise ValueError(f"l must be a positive integer, got {l!r}")

    idx = np.arange(n)
    rows, cols = idx[:, None], idx[None, :]
    same_block = (rows // l) == (cols // l)
    # j mod l always lies in [0, l - 1], so "in range(l - c, l)" reduces to ">= l - c".
    summary = (cols % l) >= (l - c)
    causal = cols <= rows
    # The two sets are combined as a union. Intersecting them (the old `and`)
    # hid summary columns outside the current block and, for c < l, even
    # dropped the diagonal.
    return ((same_block | summary) & causal).astype(float)

n = 16
l = 4
# Number of "summary" positions at the end of each length-l block.
# Keep c < l: with c == l every position satisfies j mod l >= l - c, so the
# plot no longer shows distinct summary columns.
c = 1

connectivity_matrix = fixed_sparse_attention(n, l, c)
draw_mask(connectivity_matrix)