In [2]:
import torch

def calculate_attention(Q, K, V, mask=None):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Args:
        Q: Queries, shape ``(..., m, d_k)``.
        K: Keys, shape ``(..., n, d_k)``.
        V: Values, shape ``(..., n, d_v)``.
        mask: Optional boolean tensor broadcastable to ``(..., m, n)``. ``True`` marks
            query/key pairs that may attend; ``False`` pairs receive zero weight.
            A query row with no ``True`` entry produces NaN in its output row.
            Defaults to ``None`` (no masking), which matches the original behavior.

    Returns:
        Attention output of shape ``(..., m, d_v)``.
    """
    d_k = K.size(-1)

    # Standard sqrt(d_k) scaling of the logits. Dividing by a Python scalar avoids
    # materializing a separate float32 tensor just to hold the constant.
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # -inf logits become exactly 0 after softmax.
        attention_scores = attention_scores.masked_fill(~mask, float("-inf"))

    attention_weights = torch.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_weights, V)

In [5]:
# Generate 50x50 random matrices with integer-valued elements in [-10, 10].
# A dedicated, seeded generator makes this cell reproducible on Restart & Run All
# without touching torch's global RNG state.
gen = torch.Generator().manual_seed(0)
Q = torch.randint(-10, 11, (50, 50), generator=gen).float()
K = torch.randint(-10, 11, (50, 50), generator=gen).float()
V = torch.randint(-10, 11, (50, 50), generator=gen).float()

# With entries this large the scaled scores are big, so the softmax tends to be
# nearly one-hot. In the recorded output, several rows are (almost) exact rows of V.
attention_output = calculate_attention(Q, K, V)
attention_output


tensor([[  2.0000,  -2.0000,  -4.0000,  ...,  -2.0000,   6.0000,  -6.0000],
        [ -5.0000,  -3.0000, -10.0000,  ...,  -1.0000,   1.0000,   7.0000],
        [ -1.4879,  -0.0237,   2.0490,  ...,   4.3646,   1.0478,   1.0248],
        ...,
        [ -5.9965,  -6.0050,  -1.0085,  ...,  -5.9949,  -1.0034,   8.9879],
        [ -3.0000,   3.0000,  -4.0000,  ..., -10.0000,  -5.0000,  -2.0000],
        [ -8.9995,  -8.9964,  -4.9981,  ...,  -1.9978,   9.9990,   6.9994]])


In [34]:
#
# Find the top-k keys by inner product q^T k_i, returning both their exponentiated
# scores and their indices.
#
def topk_indices(Q, K, k):
    """Return the ``k`` largest exponentiated dot-product scores and their key indices.

    Args:
        Q: Query tensor whose last dimension matches ``K``'s, e.g. ``(m, d)``.
        K: Key tensor, e.g. ``(n, d)``.
        k: Number of keys to keep per query. ``torch.topk`` raises if ``k`` exceeds
            the number of keys.

    Returns:
        Tuple ``(scores, indices)``, each of shape ``(..., m, k)``, sorted in
        descending order: ``scores[..., j] == exp(q . K[indices[..., j]])``.
    """
    logits = torch.matmul(Q, K.transpose(-2, -1))

    # exp is monotonic, so ranking the raw logits yields the same order. It also stays
    # correct when exp would overflow to inf (where all such keys would tie), and only
    # the k selected values need exponentiating.
    top_logits, top_idx = torch.topk(logits, k, dim=-1)

    return torch.exp(top_logits), top_idx

In [35]:
# Example: q . k_i = [10, 12, 8] for the three keys, so the two largest scores belong
# to keys 1 and 0, in that order.
q = torch.tensor([[1, 0, 0, 1]])
K = torch.tensor([[1, 2, 3, 9], [4, 5, 6, 8], [7, 8, 9, 1]])

topk_indices(q, K, 2)  # Output: (tensor([[exp(12), exp(10)]]), tensor([[1, 0]]))

(tensor([[162754.7969,  22026.4648]]), tensor([[1, 0]]))

In [32]:
import numpy

#
# Approximate the softmax partition function Z = sum_i exp(q^T k_i) for a single query.
# The k largest terms are summed exactly; the remaining n - k terms are estimated from
# l keys sampled uniformly without replacement.
# q is the query vector: 1 x d
# K is the matrix of key vectors: n x d
# k is the number of top-k elements to sum exactly
# l is the number of samples to draw from the remaining elements
#
def approximate_softmax(q, K, k, l, rng=None):
    """Estimate ``sum_i exp(q . k_i)`` with an exact top-k head and a sampled tail.

    Args:
        q: A single query of shape ``(1, d)``.
        K: Key matrix of shape ``(n, d)``.
        k: Number of largest terms summed exactly.
        l: Number of remaining keys to sample. It is clamped to the number of remaining
            keys, so when ``l >= n - k`` the result is exact. When ``l <= 0`` or no keys
            remain, the tail is omitted and only the top-k sum is returned.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling. Defaults to
            numpy's legacy global RNG (``numpy.random``), as before.

    Returns:
        0-dim tensor: the top-k sum plus ``(n - k) * mean(sampled terms)``, an unbiased
        estimate of the tail sum under uniform sampling without replacement.

    Raises:
        ValueError: if ``q`` is not of shape ``(1, d)``.
    """
    if q.dim() != 2 or q.size(0) != 1:
        raise ValueError(f"expected a single query of shape (1, d), got {tuple(q.shape)}")

    n, _ = K.size()

    # Exact contribution of the k largest terms.
    scores, indices = topk_indices(q, K, k)
    head_sum = scores.sum()

    top_set = set(indices[0].tolist())
    remaining = [i for i in range(n) if i not in top_set]

    # Clamp l. Sampling more than the population without replacement would raise,
    # and sampling nothing would divide by zero below.
    num_samples = min(l, len(remaining))
    if num_samples <= 0:
        return head_sum

    sampler = numpy.random if rng is None else rng
    sampled = sampler.choice(remaining, size=num_samples, replace=False)
    sampled_idx = torch.as_tensor(sampled, dtype=torch.long)

    tail_terms = torch.exp(torch.matmul(q, K[sampled_idx, :].transpose(-2, -1)))

    # Scale the sample mean up to the full tail of len(remaining) == n - k terms.
    return len(remaining) * (tail_terms.sum() / num_samples) + head_sum


In [33]:
# Example: compare the top-k + sampling estimate with the exact partition function.
# approximate_softmax samples via numpy's global RNG by default, so seed it to make
# the printed estimate reproducible on Restart & Run All.
numpy.random.seed(0)

q = torch.tensor([[1, 0, 0, 1]])
K = torch.tensor([[1, 2, 3, 9], [4, 5, 6, 8], [7, 8, 9, 1], [2, 3, 4, 5], [6, 7, 8, 9]])
print(approximate_softmax(q, K, 1, 2))

# Exact partition function sum_i exp(q . k_i). Despite the variable name, this is the
# softmax denominator, not a probability vector.
attention_scores = torch.exp(torch.matmul(q, K.transpose(-2, -1)))
true_softmax = attention_scores.sum()
true_softmax

tensor(3315263.5000)


tensor(3457876.2500)

In [None]:
# 
# Approximating attention with top-k method
# 


