# Euclidean Approximate Attention - Exploration Notebook

This notebook provides interactive exploration of the key concepts:
1. Distance matrix approximation
2. Euclidean vs Dot-Product attention
3. Linear-time approximation quality


In [None]:
import sys
sys.path.insert(0, '../src')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Device setup: prefer CUDA, then Apple's MPS backend, then fall back to CPU.
# `torch.backends.mps` is absent on PyTorch builds older than 1.12, so it is
# looked up defensively; a missing attribute simply means "no MPS" instead of
# aborting the notebook with an AttributeError.
_mps_backend = getattr(torch.backends, 'mps', None)

if torch.cuda.is_available():
    device = 'cuda'
elif _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')


## 1. Understanding Distance Matrices

A distance matrix $D_{ij} = \lVert x_i - x_j \rVert^2$ captures the pairwise squared Euclidean distances between the embeddings.


In [None]:
from distance_estimators import compute_squared_euclidean_distance_matrix, LowRankDistanceApproximator

# Seed the global torch RNG so this cell produces the same embeddings every
# time it is run (including after Restart Kernel -> Run All).
torch.manual_seed(0)

# Generate some random embeddings
n, d = 64, 32  # 64 tokens, 32 dimensions
X = torch.randn(n, d)

# Compute exact distance matrix
D = compute_squared_euclidean_distance_matrix(X)

print(f'X shape: {X.shape}')
print(f'D shape: {D.shape}')
print(f'D min: {D.min():.4f}, D max: {D.max():.4f}')

# Visualize the matrix as a heat map, using the explicit figure/axes
# interface so the colorbar is tied to this image and these axes.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(D.numpy(), cmap='viridis')
fig.colorbar(im, ax=ax)
ax.set_title('Squared Euclidean Distance Matrix')
ax.set_xlabel('Token j')
ax.set_ylabel('Token i')
plt.show()


## 2. Low-Rank Approximation

Key insight from Indyk et al.: a distance matrix can be approximated well from only a sublinear number of its entries!


In [None]:
# Sweep the approximation rank and compare each result with the exact matrix
# `D` computed from `X` in the previous cell.
ranks = [4, 8, 16, 32]
errors = []

# One panel for the exact matrix plus one panel per rank.
n_panels = len(ranks) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3))

# Left-most panel: the exact matrix, shown as the visual reference.
exact_ax = axes[0]
exact_ax.imshow(D.numpy(), cmap='viridis')
exact_ax.set_title('Exact D')
exact_ax.axis('off')

# The denominator of the relative error does not depend on the rank.
exact_norm = torch.norm(D, 'fro')

for rank, ax in zip(ranks, axes[1:]):
    approx = LowRankDistanceApproximator(rank=rank, epsilon=0.1)
    # Squared to match `D`: per the original note the approximator returns
    # (unsquared) distances -- TODO confirm against distance_estimators.
    D_approx = approx(X) ** 2

    # Relative Frobenius-norm error of the approximation.
    rel_error = torch.norm(D - D_approx, 'fro') / exact_norm
    errors.append(rel_error.item())

    ax.imshow(D_approx.numpy(), cmap='viridis')
    ax.set_title(f'Rank {rank}\nError: {rel_error:.3f}')
    ax.axis('off')

plt.tight_layout()
plt.show()

print(f'Errors by rank: {dict(zip(ranks, errors))}')


## 3. Attention Mechanisms Comparison


In [None]:
from attention import StandardAttention, EuclideanAttention, ApproximateEuclideanAttention

# Build one random input batch for all three attention variants.
# NOTE(review): this rebinds `X`, which the distance-matrix cells above use as
# a 2-D (n, d) embedding matrix. Re-running those cells after this one will
# pick up the 3-D tensor unless the cell that creates the 2-D `X` is re-run.
batch_size, seq_len, embed_dim = 1, 32, 64
num_heads = 4
X = torch.randn(batch_size, seq_len, embed_dim)

# Instantiate the three attention modules with the same embedding size and
# number of heads.
std_attn = StandardAttention(embed_dim, num_heads)
euc_attn = EuclideanAttention(embed_dim, num_heads)
approx_attn = ApproximateEuclideanAttention(embed_dim, num_heads, num_landmarks=8)

with torch.no_grad():
    # Copy the q/k/v projection weights of the standard module into the two
    # Euclidean variants so all three start from the same projections.
    # Only the `.weight` tensors are copied; any other parameters the modules
    # may have keep their own initialization.
    for target in (euc_attn, approx_attn):
        for proj_name in ('q_proj', 'k_proj', 'v_proj'):
            shared_weight = getattr(std_attn, proj_name).weight
            getattr(target, proj_name).weight.copy_(shared_weight)

    # Forward passes; the attention maps are requested only for the two
    # variants that are plotted below.
    out_std, attn_std = std_attn(X, return_attention=True)
    out_euc, attn_euc = euc_attn(X, return_attention=True)
    out_approx, _ = approx_attn(X)

# Side-by-side heat maps of the two attention maps.
# Index [0, 0] is presumably batch element 0, head 0 -- confirm against the
# attention module's return layout.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (attn_std, 'Standard (Dot-Product) Attention'),
    (attn_euc, 'Euclidean Attention'),
)
for ax, (attn_map, title) in zip(axes, panels):
    ax.imshow(attn_map[0, 0].numpy(), cmap='viridis')
    ax.set_title(title)
plt.tight_layout()
plt.show()

# Compare the outputs: cosine similarity between the two outputs, each
# flattened into a single row vector.
flat_std = out_std.flatten().unsqueeze(0)
flat_euc = out_euc.flatten().unsqueeze(0)
cos_sim = F.cosine_similarity(flat_std, flat_euc).item()
print(f'Output cosine similarity (std vs euc): {cos_sim:.4f}')
