## What if we make the random projection hashable?

Inspired by [HashAttention: Semantic Sparsity for Faster Inference](https://arxiv.org/pdf/2412.14468) by Aditya Desai et al.
### Random hashing:

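We hash $Q$ and $K$ into $b$ buckets using a random matrix $R \in \mathbb{R}^{d \times b/2}$ and taking $h(x) = \operatorname{argmax}([xR; -xR])$, where $[\cdot\,;\cdot]$ denotes concatenation. By the Johnson-Lindenstrauss lemma, the projection $xR$ approximately preserves pairwise distances (up to a constant scale factor) with high probability when $b/2$ is large enough. The argmax then acts as a locality-sensitive hash: vectors separated by a small angle tend to land in the same bucket.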
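To make the locality-sensitivity claim concrete, here is a small empirical sketch that measures how often two unit vectors with a prescribed cosine similarity land in the same bucket under $\operatorname{argmax}([xR; -xR])$. The dimension (64), bucket count (16), sample size and the helper names `cross_polytope_hash` and `collision_rate` are arbitrary choices for illustration. The collision rate should rise with cosine similarity.

```python
import torch
import torch.nn.functional as F


def cross_polytope_hash(x: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
    # Same rule as the notebook: argmax over [xR; -xR], giving 2 * R.shape[1] buckets.
    proj = x @ R
    return torch.argmax(torch.cat([proj, -proj], dim=-1), dim=-1)


def collision_rate(cos_sim: float, d: int = 64, num_buckets: int = 16,
                   n: int = 4000, seed: int = 0) -> float:
    """Fraction of vector pairs with the given cosine similarity that share a bucket."""
    g = torch.Generator().manual_seed(seed)
    R = torch.randn(d, num_buckets // 2, generator=g)
    x = F.normalize(torch.randn(n, d, generator=g), dim=-1)
    # Build a unit vector z orthogonal to x, then mix so that <x, y> = cos_sim.
    z = torch.randn(n, d, generator=g)
    z = z - (z * x).sum(-1, keepdim=True) * x  # remove the component along x
    z = F.normalize(z, dim=-1)
    y = cos_sim * x + (1 - cos_sim ** 2) ** 0.5 * z
    return (cross_polytope_hash(x, R) == cross_polytope_hash(y, R)).float().mean().item()


for c in (0.0, 0.5, 0.9, 0.99):
    print(f"cos={c:.2f}  collision rate={collision_rate(c):.3f}")
```

With 16 buckets, orthogonal pairs should collide at a rate close to chance (about 1/16), while near-duplicates should share a bucket most of the time.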

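Concretely, each vector $\vec{x} \in \mathbb{R}^d$ is projected to the hash vector $\vec{x}_h = \vec{x}R \in \mathbb{R}^{b/2}$. The hash index is the coordinate of $\vec{x}_h$ with the largest absolute value, offset by $b/2$ when that coordinate is negative, which gives $b$ possible buckets.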
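The sketch below checks that description numerically: taking the argmax over the concatenation $[xR; -xR]$ gives the same bucket as "largest-magnitude coordinate, plus $b/2$ if it is negative". Both helper names are hypothetical; exact ties between magnitudes are ignored, which is safe for continuous random inputs.

```python
import torch


def hash_via_concat(x: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
    proj = x @ R
    return torch.argmax(torch.cat([proj, -proj], dim=-1), dim=-1)


def hash_via_abs(x: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
    proj = x @ R                                   # (..., b/2)
    idx = proj.abs().argmax(dim=-1)                # coordinate with the largest magnitude
    val = proj.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
    # Positive sign -> bucket idx; negative sign -> bucket idx + b/2.
    return torch.where(val >= 0, idx, idx + R.shape[1])


torch.manual_seed(0)
R = torch.randn(16, 4)       # d = 16, b = 8 buckets
x = torch.randn(1000, 16)
print(torch.equal(hash_via_concat(x, R), hash_via_abs(x, R)))  # True
```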

### Learnable hashing:

Instead of a fixed random $R$, the model can learn its own hash functions, clustering queries and keys that should attend to each other based on semantic similarity rather than raw vector distance (i.e., the hash itself does part of the processing work). We want to encourage two things:
1. Downstream performance.
2. A balancing loss, to prevent degenerate solutions where all queries/keys are mapped to the same hash index. Candidates:
    - Entropy regularization?
    - L1 regularization on the projection matrix to prevent one dimension from dominating?
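A minimal sketch of how the two regularizers could be wired up, assuming that the random $R$ becomes a trainable parameter `W` and that hard bucket ids are relaxed to a softmax over the same $[xW; -xW]$ logits during training. The relaxation, the class and function names (`SoftHash`, `balance_loss`, `l1_penalty`), the temperature and the loss weights are all hypothetical choices, not taken from the cited paper. The balance term is $\log b$ minus the entropy of the average bucket usage, so it is zero for perfectly uniform usage and $\log b$ when everything collapses into one bucket.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftHash(nn.Module):
    """Hypothetical learnable version of the hash: R becomes the trainable parameter W."""

    def __init__(self, dim: int, num_buckets: int, temperature: float = 1.0):
        super().__init__()
        assert num_buckets % 2 == 0, "num_buckets must be even"
        self.W = nn.Parameter(torch.randn(dim, num_buckets // 2) / math.sqrt(dim))
        self.temperature = temperature

    def forward(self, x: torch.Tensor):
        proj = x @ self.W
        logits = torch.cat([proj, -proj], dim=-1) / self.temperature
        # Hard ids (same rule as the random hash) plus a differentiable soft assignment.
        return logits.argmax(dim=-1), logits.softmax(dim=-1)


def balance_loss(soft_assign: torch.Tensor) -> torch.Tensor:
    """log(num_buckets) minus the entropy of the average bucket usage."""
    usage = soft_assign.reshape(-1, soft_assign.shape[-1]).mean(dim=0)
    entropy = -(usage * torch.log(usage + 1e-9)).sum()
    return math.log(usage.numel()) - entropy


def l1_penalty(W: torch.Tensor) -> torch.Tensor:
    # Mean absolute weight of the projection matrix.
    return W.abs().mean()


torch.manual_seed(0)
hasher = SoftHash(dim=32, num_buckets=8)
q = torch.randn(4, 128, 32)      # stand-in queries: (batch, seq, dim)
ids, soft = hasher(q)
task_loss = torch.zeros(())      # placeholder for the downstream objective
loss = task_loss + 0.1 * balance_loss(soft) + 1e-4 * l1_penalty(hasher.W)
loss.backward()
print(ids.shape, f"balance loss: {balance_loss(soft).item():.4f}")
```

Because the balance term only sees the soft assignments, a very high temperature could satisfy it without balancing the hard ids, so the temperature probably needs to stay fixed or be annealed.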

```python
import torch
from torch.nn import functional as F
```

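```python
from typing import Callable, Union


def vector_hash_fn(x: torch.Tensor, num_buckets: int, R: torch.Tensor) -> torch.Tensor:
    """
    Hash each vector in x into one of num_buckets buckets.

    x: (..., D)
    R: (D, num_buckets // 2) random projection matrix
    returns: (...,) integer bucket ids in [0, num_buckets)
    """
    D = x.shape[-1]
    assert num_buckets % 2 == 0, "num_buckets must be even"
    assert R.shape == (D, num_buckets // 2)
    proj = x.to(R.dtype) @ R  # project once and reuse it for both signs
    return torch.argmax(torch.cat([proj, -proj], dim=-1), dim=-1)


def get_vector_hash(
    D: int,
    num_buckets: int,
    device: Union[str, torch.device] = "cpu",
    dtype: torch.dtype = torch.bfloat16,
) -> Callable[[torch.Tensor], torch.Tensor]:
    # Sample R once so that every call uses the same hash function.
    R = torch.randn(D, num_buckets // 2, device=device, dtype=dtype)
    return lambda x: vector_hash_fn(x, num_buckets, R)


vector_hash = get_vector_hash(D=10, num_buckets=10)
vector_hash(torch.randn(50, 10, dtype=torch.bfloat16))  # (50,) bucket ids in [0, 10)
```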
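One way the query and key hashes could be used is to let each query attend only to keys in its own bucket. The sketch below is a hypothetical illustration of that idea, not the method of the cited paper; it reimplements the same $\operatorname{argmax}([xR; -xR])$ rule in float32, and the name `bucketed_attention` and all sizes are assumptions. Queries whose bucket contains no keys would get an all `-inf` score row (NaN after softmax), so their output is set to zero here.

```python
import torch


def bucketed_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, R: torch.Tensor):
    """Hypothetical bucket-restricted attention: q, k, v are (L, D), R is (D, b/2)."""
    def h(x: torch.Tensor) -> torch.Tensor:
        proj = x @ R
        return torch.argmax(torch.cat([proj, -proj], dim=-1), dim=-1)

    q_ids, k_ids = h(q), h(k)                        # (Lq,), (Lk,)
    same_bucket = q_ids[:, None] == k_ids[None, :]   # (Lq, Lk) boolean mask
    scores = (q @ k.T) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~same_bucket, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    # Rows with no matching keys are NaN after softmax; replace them with zeros.
    weights = torch.nan_to_num(weights, nan=0.0)
    return weights @ v, same_bucket


torch.manual_seed(0)
D, num_buckets = 16, 8
R = torch.randn(D, num_buckets // 2)
q, k, v = torch.randn(64, D), torch.randn(64, D), torch.randn(64, D)
out, mask = bucketed_attention(q, k, v, R)
print(out.shape, f"fraction of (q, k) pairs kept: {mask.float().mean():.3f}")
```

The kept fraction is a rough proxy for sparsity; with balanced buckets it would be close to $1/b$, which is why the balancing loss matters once the hash is learned.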