In [44]:
import torch

def calculate_attention(Q, K, V, B):
    n, d = Q.size()

    # Calculate the dot product of Q and K
    attention_scores = torch.matmul(Q, K.transpose(-2, -1))

    # Normalize the attention scores by dividing by d.
    attention_scores = attention_scores / d

    # Numerical stability
    attention_scores -= B*B
    
    # Apply softmax to get the attention weights
    attention_weights = torch.softmax(attention_scores, dim=-1)
    
    # Multiply the attention weights with the value vectors
    attention_output = torch.matmul(attention_weights, V)
    
    return attention_output

In [45]:
# Generate 50x50 random matrices with elements in range [-10, 10]
Q = torch.randint(-10, 11, (50, 50)).float()
K = torch.randint(-10, 11, (50, 50)).float()
V = torch.randint(-10, 11, (50, 50)).float()

# Call the calculate_attention function
attention_output = calculate_attention(Q, K, V, B=10.0)

# Print the attention output
print(attention_output) 


tensor([[-7.5545e+00, -1.0099e+00, -7.7820e+00,  ...,  6.8297e+00,
          7.6803e+00,  7.6429e+00],
        [-7.4637e+00, -4.9500e+00,  4.2071e+00,  ...,  5.0288e+00,
         -6.3750e-03, -7.4135e+00],
        [ 1.1425e+00, -5.0924e-01, -1.3612e+00,  ...,  1.9775e+00,
         -4.6699e+00,  1.2864e+00],
        ...,
        [ 4.9724e+00,  3.8048e-02, -9.8822e+00,  ...,  4.9183e+00,
          3.0044e+00,  4.9205e+00],
        [ 5.7137e+00, -4.0068e+00,  1.7566e+00,  ..., -3.7334e-01,
          2.4066e+00,  2.5427e+00],
        [-2.9471e+00,  1.0142e+00, -5.5731e+00,  ...,  6.0247e+00,
          8.5180e+00, -1.2846e+00]])


In [32]:
# 
# Find top-k indices of inner products q^T k_i for every query q and set of keys K
# This is done naively here, in O(n^2) time
# 
def topk_indices_naive(Q, K, k, B):
    n, d = K.size()

    # Calculate the dot product of Q and K
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) 

    # Normalize the attention scores by dividing by d.
    attention_scores = attention_scores / d

    # Apply the exponential function to the attention scores
    attention_scores = torch.exp(attention_scores - B * B)
    
    # Find the top-k indices of the attention scores
    topk_scores, topk_indices = torch.topk(attention_scores, k, dim=-1)
    
    return topk_scores, topk_indices

In [33]:
# Example:
q = torch.tensor([[1, 0, 0, 1]])
K = torch.tensor([[1, 2, 3, 9], [4, 5, 6, 8], [7, 8, 9, 1]])

topk_indices_naive(q, K, 2, 10)

(tensor([[7.4689e-43, 4.5262e-43]]), tensor([[1, 0]]))

In [34]:
import numpy

#
# Approximating the expected value of a value vector v with underlying distribution
# p(i) = softmax(q^T k_i) / Z, where Z is the partition function
# 
# Parameters:
#   q is the query vector: 1 x d
#   K is the matrix of key vectors: n x d
#   v is the value vector: 1 x n
#   k is the number of top-k elements to consider
#   l is the number of samples to draw from the remaining elements
#   topk_indices_func is a function that returns the top-k indices of the attention scores
#   B is the maximum value of the q,k,v elements. We use this to avoid numerical instability.
# 
def approximate_softmax_expectation(q, K, v, k, l, topk_indices_func, B):
    n, d = K.size()

    # Find the top-k indices of the attention scores
    scores, indices = topk_indices_func(q, K, k, B)

    # From the n-k remaining elements, draw l samples.
    # TODO: Implement this in O(l) time.
    remaining_indices = set(range(n)) - set(indices)

    # Randomly sample l indices from the remaining elements
    random_indices = numpy.random.choice(list(remaining_indices), l, replace=False)

    # Now we'll evaluate the partition function and the expectation separately.

    approx_partition = 0
    approx_expectation = 0
    for index in random_indices:
        # Calculate the attention score for the remaining elements
        attention_score = torch.exp(torch.dot(q, K[index]) / d - B*B)

        # Add the attention score to the partition function
        approx_partition += attention_score

        # Add the attention score times the value to the expectation
        approx_expectation += attention_score * v[index]

    approx_partition *= ((n-k) / l)
    approx_expectation *= ((n-k) / l)

    approx_partition += scores.sum()
    approx_expectation += torch.sum(scores * v[indices[0]])

    # Return the approximate softmax partition function.
    return approx_expectation / approx_partition


In [35]:
# 
# Approximates the attention output using sampling.
# 
# Parameters:
#   Q is the matrix of query vectors: n x d
#   K is the matrix of key vectors: n x d
#   V is the matrix of value vectors: n x d
#   k is the number of top-k elements to consider
#   l is the number of samples to draw from the remaining elements
#   topk_indices_func is the function to find the top-k indices
def sampling_attention(Q, K, V, k, l, topk_indices_func, B):

    output = torch.zeros_like(V)

    # For all rows in Q...
    for i in range(Q.size(0)): # n
        # For all columns in V...
        for j in range(V.size(1)): # d
            # Approximate the expected value of the value vector
            output[i, j] = approximate_softmax_expectation(Q[i], K, V[i], k, l, topk_indices_func, B)

    return output

In [46]:
import time

# Generate 50x50 random matrices with elements in range [-10, 10]
Q = torch.randint(-10, 11, (200, 200)).float()
K = torch.randint(-10, 11, (200, 200)).float()
V = torch.randint(-10, 11, (200, 200)).float()

# Call the sampling_attention function.
# Measure the time it takes to run the function.
start_time = time.time()
attention_output = sampling_attention(Q, K, V, 50, 50, topk_indices_naive, B=10.0)
end_time = time.time()
print("Time taken for approximate attention:", end_time - start_time)

# Compare with the exact attention output
# Calculate the time taken to run the calculate_attention function.
start_time = time.time()
exact_attention_output = calculate_attention(Q, K, V,B=10.0)
end_time = time.time()
print("Time taken for exact attention:", end_time - start_time)

# Print the mean absolute error
print("Mean error: ", torch.mean(torch.abs(attention_output - exact_attention_output)))

Time taken for approximate attention: 43.3422908782959
Time taken for exact attention: 0.008199930191040039
Mean error:  tensor(4.2154)


In [47]:
class angularLSH:
    def __init__ (self, K, r, c, B):
        self.K = K
        self.c = c
        self.r = r
        self.n, self.d = K.size()

        # Normalize the key vectors by dividing with d:
        self.K /= self.d

        # Normalize further by dividing all the key vectors by B*B
        self.K /= (B*B)

        # Now add an extra dimension to the key vectors to 
        # make them all have norm 1. Now all the key vectors have d+1 dimensions.
        self.K = torch.cat((self.K, torch.sqrt(1 - torch.sum(self.K**2, dim=1, keepdim=True))), dim=1)
        self.d += 1

        # LSH works as follows:
        # 1. To hash a single vector, concatenate k = O(log n) smaller hashes. 
        #    That's one hash table. That means we have 2^k = O(n) buckets
        # 2. Maintain L = n^s hash tables where s = log(1-arccos(b)/pi) / log(1-arccos(s)/pi)
        self.k = int(torch.ceil(torch.log2(self.n)))
        self.L = torch.log(1-torch.arccos(self.c * self.r)/torch.pi) / torch.log(1-torch.arccos(self.r)/torch.pi)

        # A hash is just the sign of the dot product of the vector 
        # with a random vector on the unit sphere.
        # (Coordinates are drawn from a Gaussian distribution.)
        self.hash_vectors = torch.randn(self.L, self.k, self.d)

        # We'll store L hash tables, each with 2^k buckets. Each bucket
        # will store a list of indices of the key vectors.
        self.hash_tables = [{} for _ in range(self.L)]

        # Now we hash all the key vectors L times:
        for i in range(self.n):
            ki = self.K[i] # Get the i-th key vector: 1 x d

            # Compute the hash of the key vector for all L hash tables.
            for j in range(self.L):
                hj = self.hash(ki, self.hash_vectors[j]) # number between 0 and 2^k - 1

                # Add the index of the key vector to the corresponding bucket.
                if hj not in self.hash_tables[j]:
                    self.hash_tables[j][hj] = []

                self.hash_tables[j][hj].append(i)

    #
    # Hash a single vector x using the hash function h.
    # h consists of k random vectors on the unit sphere.
    #
    # Parameters:
    #  x is the input vector: 1 x (d+1)
    #  h is the hash function: k x (d+1)
    # 
    # Returns:
    #  The hash of the input vector x: a number between 0 and 2^k - 1
    # 
    def hash(self, x, h):
        # If the input vector x has d dimensions, we add an extra dimension of 0.
        # This is needed for the query vectors, but not the key vectors because
        # we already added an extra dimension to them.
        if len(x) == self.d - 1:
            x = torch.cat((x, torch.tensor([0.0])))

        hash_value = 0
        for i in range(self.k):
            if torch.dot(x, h[i]) >= 0:
                hash_value += 2**i

        return hash_value

    # Function to query how many (and which) key vectors are in the same bucket 
    # as the query vector q.
    # 
    # We limit the number of returned key vectors to max_results.
    # We also require the keys to have a dot product of at most cr with the query vector.
    #
    # Runtime: O(max(max_results, number of keys in the bucket))
    def query_bucket_size(self, q, max_results):

        # If the query vector q has d dimensions, we add an extra dimension of 0.
        if len(q) == self.d - 1:
            q = torch.cat((q, torch.tensor([0.0])))

        distinct_keys = set()
        for j in range(self.L):
            hj = self.hash(q, self.hash_vectors[j])

            if hj in self.hash_tables[j]:
                for key_index in self.hash_tables[j][hj]:

                    # Check if the key vector has a dot product of at most cr with the query vector.
                    if torch.dot(q, self.K[key_index]) >= self.c * self.r:
                        continue
                    
                    # Add the key index to the set of distinct keys.
                    distinct_keys.add(key_index)

                    # If we have reached the maximum number of results, return the keys.
                    if len(distinct_keys) >= max_results:
                        return distinct_keys

        return distinct_keys, len(distinct_keys)


In [None]:
# We will create a sequence of concentric LSH objects with increasing r values.
def topk_indices_lsh_preprocessing(K, B):
    n, d = K.size()

    c = d/(B*B*B*B)
    current_s = 0
    current_b = c * s



In [None]:
# This function finds the top-k inner products q^T k_i for a query q and a set of 
# keys K. It does so in O(k) time by using the concentric circle LSH idea.
def topk_indices_fast_lsh(q, K, k, B):
    pass