# 🧠 Mixture-of-Logits: Theory and Universal Approximation

## 🎯 Learning Objectives

Build a deep understanding of:
1. **The mathematical theory** behind Mixture-of-Logits (MoL)
2. **The Universal Approximation Property**: why MoL can represent any similarity function
3. **Matrix rank theory** and how it relates to expressiveness
4. **The comparison with dot products** and the limits of the low-rank bottleneck

## 📖 Excerpts from the Paper

### Section 2.1 - Key Insights:

> *"Our key insight is that learned similarity approaches are but different ways to increase the expressiveness of the retrieval stage. Formally, for a query q and an item x, the expressiveness of the similarity function boils down to deriving alternative parameterizations of p(x|q) matrices, with full rank matrices being the most expressive among them."*

> *"Dot products, on the other hand, induces a low-rank bottleneck due to the dimensionality of the embedding, i.e., ln p(x|q) ∝ ⟨f(q), g(x)⟩ (f(q), g(x) ∈ R^d)."*

### Mathematical Foundation:

**MoL Definition (Equation 1)**:
$$\phi(q,x) = \sum_{p=1}^{P} \pi_p(q,x) \langle f_p(q), g_p(x) \rangle$$

**Outer Product Form**:
$$\phi(q,x) = \sum_{p_q=1}^{P_q} \sum_{p_x=1}^{P_x} \pi_{p_q,p_x}(q,x) \left\langle \frac{f_{p_q}(q)}{||f_{p_q}(q)||_2}, \frac{g_{p_x}(x)}{||g_{p_x}(x)||_2} \right\rangle$$
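
To make Equation 1 and its outer-product form concrete, here is a minimal NumPy sketch that scores a single (q, x) pair. The function name `mol_score`, the tensor shapes, and the softmax over raw gate logits are illustrative assumptions; in the paper the gating weights π come from a learned function of (q, x), which this sketch simply takes as an input.

```python
import numpy as np

def mol_score(f_q, g_x, gate_logits):
    """Score one (query, item) pair with the outer-product form of MoL.

    f_q:         [P_q, d] query-side component embeddings (hypothetical shape)
    g_x:         [P_x, d] item-side component embeddings (hypothetical shape)
    gate_logits: [P_q, P_x] unnormalized gating scores for this pair (assumed given)
    """
    # L2-normalize each component embedding, as in the outer-product form
    f_hat = f_q / np.linalg.norm(f_q, axis=1, keepdims=True)
    g_hat = g_x / np.linalg.norm(g_x, axis=1, keepdims=True)

    # All P_q x P_x pairwise cosine similarities (the component logits)
    logits = f_hat @ g_hat.T

    # Softmax over all component pairs so the gating weights sum to 1
    z = gate_logits - gate_logits.max()
    pi = np.exp(z) / np.exp(z).sum()

    return float((pi * logits).sum())

rng = np.random.default_rng(0)
f_q = rng.standard_normal((2, 8))   # P_q = 2
g_x = rng.standard_normal((3, 8))   # P_x = 3
score = mol_score(f_q, g_x, rng.standard_normal((2, 3)))
print(f"MoL score: {score:.4f}")
```

Because the gating weights form a convex combination of cosine similarities, the score always lies in [-1, 1]. If the gate puts all its weight on one component pair, the score reduces to that pair's cosine similarity.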

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List
import math
from scipy.linalg import svd
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔧 Device: {device}")

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 🔍 Part 1: Matrix Rank and Expressiveness

### 📊 Theory:

**Similarity Matrix**: Given Q queries and X items, the similarity matrix is S ∈ R^(Q×X)
- **Full Rank**: rank(S) = min(Q, X) → Most Expressive
- **Low Rank**: rank(S) ≤ d (embedding dim) → Limited Expressiveness

**Dot Product Limitation**: 
- S[i,j] = f(q_i)^T g(x_j)
- S = F G^T where F ∈ R^(Q×d), G ∈ R^(X×d) 
- rank(S) ≤ min(Q, X, d) → **Bottleneck at dimension d**
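
The bound rank(S) ≤ d has a measurable cost. By the Eckart-Young theorem, the best rank-d approximation of a matrix in Frobenius norm is its truncated SVD. The sketch below therefore reports the smallest relative error that any d-dimensional dot-product model could reach on a full-rank target. The target matrix and the sizes are synthetic and chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_queries, num_items = 100, 500
target = rng.standard_normal((num_queries, num_items))  # synthetic full-rank target

# Singular values in descending order
sigma = np.linalg.svd(target, compute_uv=False)
total = np.sqrt((sigma ** 2).sum())  # Frobenius norm of the target

best_errors = {}
for d in [16, 32, 64, 100]:
    # Eckart-Young: the optimal rank-d error is the norm of the discarded singular values
    best_errors[d] = np.sqrt((sigma[d:] ** 2).sum()) / total
    print(f"d = {d:3d}: best relative error of any rank-{d} model = {best_errors[d]:.3f}")
```

The error only reaches zero once d equals min(Q, X), which is the sense in which full-rank matrices are the most expressive.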

In [None]:
def analyze_matrix_rank_expressiveness():
    """
    Demonstrate the relationship between matrix rank and expressiveness
    """
    print("🔍 Matrix Rank Analysis")
    print("=" * 50)
    
    # Parameters
    num_queries = 100
    num_items = 500
    embedding_dims = [16, 32, 64, 128]
    
    results = []
    
    for d in embedding_dims:
        # Generate random embeddings
        Q = np.random.randn(num_queries, d)
        X = np.random.randn(num_items, d)
        
        # Compute dot product similarity matrix
        S_dot = Q @ X.T  # [num_queries, num_items]
        
        # Compute rank
        rank_dot = np.linalg.matrix_rank(S_dot)
        max_possible_rank = min(num_queries, num_items)
        
        # Expressiveness ratio
        expressiveness = rank_dot / max_possible_rank
        
        results.append({
            'dim': d,
            'rank': rank_dot,
            'max_rank': max_possible_rank,
            'expressiveness': expressiveness
        })
        
        print(f"Embedding Dim {d:3d}: Rank = {rank_dot:3d}/{max_possible_rank} ({expressiveness:.3f})")
    
    return results

def demonstrate_full_rank_approximation():
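    """
    Show how summing P low-rank components raises the rank of the mixture.

    The components are random and unfitted, so the reported error is only an
    untrained baseline, not a measure of how well MoL can fit the target.
    """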
    print("\n🧠 MoL Full-Rank Approximation")
    print("=" * 50)
    
    # Create a target full-rank similarity matrix
    Q, X = 20, 30
    target_matrix = np.random.randn(Q, X)
    target_rank = np.linalg.matrix_rank(target_matrix)
    
    print(f"Target Matrix: {Q}×{X}, Rank = {target_rank}")
    
    # Approximate with different numbers of components
    component_counts = [1, 2, 4, 8, 16]
    embedding_dim = 8
    
    approximation_errors = []
    
    for P in component_counts:
        # Generate MoL components
        mol_approx = np.zeros((Q, X))
        
        for p in range(P):
            # Random component embeddings
            f_p = np.random.randn(Q, embedding_dim)
            g_p = np.random.randn(X, embedding_dim)
            
            # Uniform, input-independent gating weights (simplified)
            weight = 1.0 / P
            
            # Component similarity
            component_sim = f_p @ g_p.T
            mol_approx += weight * component_sim
        
        # Compute approximation error
        error = np.linalg.norm(target_matrix - mol_approx, 'fro')
        approximation_errors.append(error)
        
        mol_rank = np.linalg.matrix_rank(mol_approx)
        print(f"P={P:2d}: Rank = {mol_rank:2d}, Error = {error:.3f}")
    
    return component_counts, approximation_errors

# Run analysis
rank_results = analyze_matrix_rank_expressiveness()
components, errors = demonstrate_full_rank_approximation()

## 🎯 Part 2: Universal Approximation Theorem for MoL

### 📖 Theorem (Informal):

**Mixture-of-Logits Universal Approximation**:
- With enough components P, MoL can approximate any similarity function
- In particular, MoL can produce matrices whose rank exceeds the embedding dimension
- This is the argument for why MoL can outperform dot products

### 🔬 Proof Sketch:
1. Each component yields a rank-r matrix (r ≤ embedding_dim)
2. A weighted combination with fixed weights can raise the total rank up to P × r; input-dependent gating π_p(q,x) acts elementwise and is not limited by this bound
3. With P large enough → the mixture can reach full rank
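
The proof sketch above can be probed numerically. The following NumPy sketch compares fixed per-component weights with gating that varies per (q, x) pair, as π_p(q,x) does in Equation 1. The gate logits are random stand-ins for a learned gating function, and all sizes are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_x, d, P = 30, 40, 2, 3

# Per-component logit matrices F_p @ G_p.T, each of rank <= d
components = [
    rng.standard_normal((n_q, d)) @ rng.standard_normal((d, n_x))
    for _ in range(P)
]

# (a) Fixed weights: one scalar per component, shared by every (q, x) pair
fixed = sum(c / P for c in components)

# (b) Input-dependent gating: a separate softmax over components for each pair.
# The gate logits are random here, standing in for a learned gating function.
gate_logits = rng.standard_normal((P, n_q, n_x))
gate = np.exp(gate_logits) / np.exp(gate_logits).sum(axis=0, keepdims=True)
gated = sum(gate[p] * components[p] for p in range(P))  # elementwise weighting

rank_fixed = np.linalg.matrix_rank(fixed)
rank_gated = np.linalg.matrix_rank(gated)
print(f"P * d = {P * d}")
print(f"rank with fixed weights:        {rank_fixed}")
print(f"rank with input-dependent gate: {rank_gated}")
```

With fixed weights the rank cannot exceed P × d. The elementwise gate multiplies each component by a matrix that is itself high-rank, which is what lets the mixture go beyond that bound.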

In [None]:
class UniversalApproximationDemo:
    """
    Demonstrate Universal Approximation property of MoL
    """
    
    def __init__(self, query_dim: int = 32, item_dim: int = 32, embedding_dim: int = 16):
        self.query_dim = query_dim
        self.item_dim = item_dim
        self.embedding_dim = embedding_dim
    
    def create_target_similarity_function(self, complexity: str = 'nonlinear'):
        """
        Create various target similarity functions to approximate
        """
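        # Note: query_dim and item_dim are used here as the NUMBER of queries and items
        Q = torch.randn(self.query_dim, 64)  # Query features: one 64-d row per query
        X = torch.randn(self.item_dim, 64)   # Item features: one 64-d row per item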
        
        if complexity == 'linear':
            # Linear similarity (dot product baseline)
            W = torch.randn(64, 64)
            S = Q @ W @ X.T
        
        elif complexity == 'quadratic':
            # Quadratic similarity
            S = torch.zeros(self.query_dim, self.item_dim)
            for i in range(self.query_dim):
                for j in range(self.item_dim):
                    S[i, j] = torch.sum(Q[i] * X[j])**2
        
        elif complexity == 'nonlinear':
            # Complex nonlinear similarity
            S = torch.zeros(self.query_dim, self.item_dim)
            for i in range(self.query_dim):
                for j in range(self.item_dim):
                    dot_prod = torch.dot(Q[i], X[j])
                    S[i, j] = torch.tanh(dot_prod) + 0.5 * torch.sin(dot_prod)
        
        return S, Q, X
    
    def approximate_with_mol(self, target_S: torch.Tensor, Q: torch.Tensor, X: torch.Tensor, 
                           num_components: int = 8, num_iterations: int = 1000):
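        """
        Approximate target similarity using a simplified MoL.

        Simplifications: one free embedding table per component, and a single
        learnable scalar gate per component (not the input-dependent
        pi_p(q, x) of Equation 1). Outputs are convex mixes of cosine
        similarities, so every entry lies in [-1, 1].
        """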
        # Initialize MoL components
        query_embeddings = []
        item_embeddings = []
        gating_weights = []
        
        for p in range(num_components):
            # Component embeddings (learnable)
            f_p = nn.Parameter(torch.randn(Q.size(0), self.embedding_dim))
            g_p = nn.Parameter(torch.randn(X.size(0), self.embedding_dim))
            
            query_embeddings.append(f_p)
            item_embeddings.append(g_p)
            
            # Gating weight (learnable)
            w_p = nn.Parameter(torch.tensor(1.0 / num_components))
            gating_weights.append(w_p)
        
        # Optimizer
        all_params = query_embeddings + item_embeddings + gating_weights
        optimizer = torch.optim.Adam(all_params, lr=0.01)
        
        losses = []
        
        for iteration in range(num_iterations):
            optimizer.zero_grad()
            
            # Compute MoL similarity
            mol_similarity = torch.zeros_like(target_S)
            
            # Ensure gating weights sum to 1
            weights = F.softmax(torch.stack(gating_weights), dim=0)
            
            for p in range(num_components):
                # Normalize embeddings
                f_p_norm = F.normalize(query_embeddings[p], dim=1)
                g_p_norm = F.normalize(item_embeddings[p], dim=1)
                
                # Component similarity
                component_sim = torch.mm(f_p_norm, g_p_norm.t())
                
                # Weighted sum
                mol_similarity += weights[p] * component_sim
            
            # Loss (MSE)
            loss = F.mse_loss(mol_similarity, target_S)
            
            loss.backward()
            optimizer.step()
            
            losses.append(loss.item())
            
            if iteration % 200 == 0:
                print(f"Iteration {iteration:4d}: Loss = {loss.item():.6f}")
        
        return mol_similarity.detach(), losses, weights.detach()

# Demonstration
print("🎯 Universal Approximation Demonstration")
print("=" * 50)

demo = UniversalApproximationDemo(query_dim=20, item_dim=25, embedding_dim=8)

# Test different complexity levels
complexities = ['linear', 'quadratic', 'nonlinear']
approximation_results = {}

for complexity in complexities:
    print(f"\n📊 Approximating {complexity} similarity...")
    
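    target_S, Q, X = demo.create_target_similarity_function(complexity)
    # Rescale so target entries share the [-1, 1] range of the cosine-based MoL output
    target_S = target_S / target_S.abs().max()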
    target_rank = torch.linalg.matrix_rank(target_S).item()
    
    print(f"Target matrix rank: {target_rank}")
    
    # Approximate with MoL
    mol_S, losses, weights = demo.approximate_with_mol(target_S, Q, X, num_components=12, num_iterations=800)
    
    mol_rank = torch.linalg.matrix_rank(mol_S).item()
    final_error = F.mse_loss(mol_S, target_S).item()
    
    print(f"MoL approximation rank: {mol_rank}")
    print(f"Final approximation error: {final_error:.6f}")
    
    approximation_results[complexity] = {
        'target_rank': target_rank,
        'mol_rank': mol_rank,
        'error': final_error,
        'losses': losses,
        'weights': weights
    }

## 📊 Part 3: Dot Product vs MoL Comparison

In [None]:
def compare_expressiveness_dot_vs_mol():
    """
    Direct comparison between dot product and MoL expressiveness
    """
    print("\n⚔️ Dot Product vs MoL Expressiveness")
    print("=" * 50)
    
    # Parameters
    num_queries, num_items = 50, 80
    feature_dim = 128
    embedding_dims = [8, 16, 32, 64]
    
    # Generate random query and item features
    queries = torch.randn(num_queries, feature_dim)
    items = torch.randn(num_items, feature_dim)
    
    results = []
    
    for embed_dim in embedding_dims:
        print(f"\n📏 Embedding Dimension: {embed_dim}")
        
        # Dot Product Model
        dot_query_proj = nn.Linear(feature_dim, embed_dim)
        dot_item_proj = nn.Linear(feature_dim, embed_dim)
        
        with torch.no_grad():
            q_emb = F.normalize(dot_query_proj(queries), dim=1)
            i_emb = F.normalize(dot_item_proj(items), dim=1)
            dot_similarity = torch.mm(q_emb, i_emb.t())
            dot_rank = torch.linalg.matrix_rank(dot_similarity).item()
        
        # MoL Model with different component counts
        mol_ranks = []
        component_counts = [2, 4, 8, 16]
        
        for num_comp in component_counts:
            mol_similarity = torch.zeros(num_queries, num_items)
            
            for p in range(num_comp):
                # Random component projections
                f_proj = nn.Linear(feature_dim, embed_dim)
                g_proj = nn.Linear(feature_dim, embed_dim)
                
                with torch.no_grad():
                    f_p = F.normalize(f_proj(queries), dim=1)
                    g_p = F.normalize(g_proj(items), dim=1)
                    component_sim = torch.mm(f_p, g_p.t())
                    mol_similarity += (1.0 / num_comp) * component_sim
            
            mol_rank = torch.linalg.matrix_rank(mol_similarity).item()
            mol_ranks.append(mol_rank)
        
        max_possible_rank = min(num_queries, num_items)
        
        print(f"   Dot Product Rank: {dot_rank:2d} / {max_possible_rank} ({dot_rank/max_possible_rank:.3f})")
        print(f"   MoL Ranks:")
        for i, (nc, mr) in enumerate(zip(component_counts, mol_ranks)):
            print(f"     {nc:2d} components: {mr:2d} / {max_possible_rank} ({mr/max_possible_rank:.3f})")
        
        results.append({
            'embed_dim': embed_dim,
            'dot_rank': dot_rank,
            'mol_ranks': mol_ranks,
            'max_rank': max_possible_rank
        })
    
    return results

# Run comparison
comparison_results = compare_expressiveness_dot_vs_mol()

## 📈 Visualization and Analysis

In [None]:
# Create comprehensive visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# 1. Matrix Rank vs Embedding Dimension
dims = [r['dim'] for r in rank_results]
ranks = [r['rank'] for r in rank_results]
max_ranks = [r['max_rank'] for r in rank_results]

axes[0, 0].plot(dims, ranks, 'o-', label='Achieved Rank', linewidth=2, markersize=8)
axes[0, 0].axhline(y=max_ranks[0], color='red', linestyle='--', label='Max Possible Rank', alpha=0.7)
axes[0, 0].set_xlabel('Embedding Dimension')
axes[0, 0].set_ylabel('Matrix Rank')
axes[0, 0].set_title('Dot Product: Rank vs Embedding Dimension')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. MoL Approximation Error vs Components
axes[0, 1].plot(components, errors, 'go-', linewidth=2, markersize=8)
axes[0, 1].set_xlabel('Number of Components')
axes[0, 1].set_ylabel('Approximation Error')
axes[0, 1].set_title('MoL: Approximation Error vs Components')
axes[0, 1].set_yscale('log')
axes[0, 1].grid(True, alpha=0.3)

# 3. Universal Approximation Results
complexities = list(approximation_results.keys())
target_ranks = [approximation_results[c]['target_rank'] for c in complexities]
mol_ranks = [approximation_results[c]['mol_rank'] for c in complexities]
errors = [approximation_results[c]['error'] for c in complexities]

x = np.arange(len(complexities))
width = 0.35

axes[0, 2].bar(x - width/2, target_ranks, width, label='Target Rank', alpha=0.8)
axes[0, 2].bar(x + width/2, mol_ranks, width, label='MoL Approximation Rank', alpha=0.8)
axes[0, 2].set_xlabel('Similarity Complexity')
axes[0, 2].set_ylabel('Matrix Rank')
axes[0, 2].set_title('Universal Approximation: Rank Achievement')
axes[0, 2].set_xticks(x)
axes[0, 2].set_xticklabels(complexities)
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# 4. Training Loss for Universal Approximation
for i, complexity in enumerate(complexities):
    losses = approximation_results[complexity]['losses']
    axes[1, 0].plot(losses, label=f'{complexity.capitalize()}', linewidth=2)

axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('MSE Loss')
axes[1, 0].set_title('MoL Training: Convergence Analysis')
axes[1, 0].set_yscale('log')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 5. Expressiveness Comparison
embed_dims = [r['embed_dim'] for r in comparison_results]
dot_expressiveness = [r['dot_rank'] / r['max_rank'] for r in comparison_results]
mol_expressiveness = [[mr / r['max_rank'] for mr in r['mol_ranks']] for r in comparison_results]

axes[1, 1].plot(embed_dims, dot_expressiveness, 'ro-', label='Dot Product', linewidth=2, markersize=8)

# Plot MoL with different component counts
component_counts = [2, 4, 8, 16]
colors = ['green', 'blue', 'purple', 'orange']
for i, (nc, color) in enumerate(zip(component_counts, colors)):
    mol_expr = [me[i] for me in mol_expressiveness]
    # Pass the color by name: first letters such as 'p' and 'o' are marker codes, not colors
    axes[1, 1].plot(embed_dims, mol_expr, 'o-', color=color, label=f'MoL-{nc}', linewidth=2, markersize=6)

axes[1, 1].set_xlabel('Embedding Dimension')
axes[1, 1].set_ylabel('Expressiveness (Rank Ratio)')
axes[1, 1].set_title('Expressiveness: Dot Product vs MoL')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# 6. Component Weight Distribution
if 'nonlinear' in approximation_results:
    weights = approximation_results['nonlinear']['weights'].numpy()
    axes[1, 2].bar(range(len(weights)), weights, alpha=0.7, color='skyblue')
    axes[1, 2].set_xlabel('Component Index')
    axes[1, 2].set_ylabel('Weight')
    axes[1, 2].set_title('MoL Component Weight Distribution')
    axes[1, 2].grid(True, alpha=0.3)
else:
    axes[1, 2].text(0.5, 0.5, 'Component weights\nnot available', 
                   ha='center', va='center', transform=axes[1, 2].transAxes)
    axes[1, 2].set_title('Component Analysis')

plt.tight_layout()
plt.show()

print("\n📊 Visualization completed")

## 🎓 Key Insights and Conclusions

### 🔍 Observations from the Experiments:

1. **Dot Product Limitation**:
   - Matrix rank is capped by the embedding dimension
   - When d << min(Q, X), expressiveness is severely limited
   - Raising d to recover expressiveness is often impractical (memory cost and overfitting risk)

2. **MoL Advantages**:
   - Can reach a rank higher than the embedding dimension
   - Expressiveness grows with the number of components P
   - The toy experiments illustrate the universal approximation property; they do not prove it

3. **Practical Implications**:
   - MoL suits complex similarity patterns
   - There is a trade-off between accuracy and computational cost
   - Load balancing matters for avoiding component collapse
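
Component collapse can be monitored with a simple utilization statistic. The sketch below is one possible diagnostic, not the paper's load-balancing method: it averages the gating weights over a batch and reports their normalized entropy. The function name `gate_utilization` and the `[num_pairs, P]` layout are assumptions made for this example.

```python
import numpy as np

def gate_utilization(gate_weights):
    """Summarize how evenly a batch of gating weights uses the P components.

    gate_weights: [num_pairs, P], each row a softmax over components
    (hypothetical layout for this sketch).
    Returns the mean weight per component and a normalized entropy in [0, 1].
    """
    usage = gate_weights.mean(axis=0)                 # average weight per component
    entropy = -(usage * np.log(usage + 1e-12)).sum()  # small epsilon avoids log(0)
    return usage, entropy / np.log(len(usage))

# Balanced gating: every pair spreads weight evenly over 4 components
balanced = np.full((1000, 4), 0.25)
# Collapsed gating: almost all weight on component 0
collapsed = np.tile([0.97, 0.01, 0.01, 0.01], (1000, 1))

_, h_balanced = gate_utilization(balanced)
_, h_collapsed = gate_utilization(collapsed)
print(f"normalized entropy, balanced:  {h_balanced:.3f}")
print(f"normalized entropy, collapsed: {h_collapsed:.3f}")
```

A normalized entropy near 1 means the components share the load; a value near 0 suggests that a few components dominate and the rest contribute little.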

### 📖 Theoretical Foundation:

**Theorem**: MoL with P components and embedding dimension d can approximate matrices of rank up to P × d (under ideal conditions).

**Proof Idea**:
- Each component p contributes a matrix F_p G_pᵀ with entries ⟨f_p(q), g_p(x)⟩, so its rank is ≤ d
- A linear combination of P such components has rank ≤ P × d
- With suitable gating weights, a high rank can be achieved
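
The rank ≤ P × d step can be checked directly for fixed, input-independent weights: the weighted sum of P component matrices equals a single dot product over concatenated embeddings of width P × d. The weights and sizes below are illustrative values, not taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_x, d, P = 20, 30, 4, 3
weights = np.array([0.5, 0.3, 0.2])  # fixed mixture weights (illustrative values)

F = [rng.standard_normal((n_q, d)) for _ in range(P)]
G = [rng.standard_normal((n_x, d)) for _ in range(P)]

# Weighted sum of the P component matrices F_p @ G_p.T
mixture = sum(w * f @ g.T for w, f, g in zip(weights, F, G))

# The same matrix as ONE dot product over concatenated, sqrt-weighted embeddings
F_cat = np.hstack([np.sqrt(w) * f for w, f in zip(weights, F)])  # [n_q, P*d]
G_cat = np.hstack([np.sqrt(w) * g for w, g in zip(weights, G)])  # [n_x, P*d]
single_dot = F_cat @ G_cat.T

max_diff = np.abs(mixture - single_dot).max()
rank = np.linalg.matrix_rank(mixture)
print(f"max |mixture - single_dot| = {max_diff:.2e}")
print(f"rank = {rank}, bound P * d = {P * d}")
```

Since `F_cat` has only P × d columns, the product cannot have a higher rank, which is the bound stated in the proof idea.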

### 🚀 Practical Applications:

1. **Recommendation Systems**: Model complex user-item interactions
2. **Information Retrieval**: Capture nuanced query-document relationships  
3. **Question Answering**: Rich semantic matching beyond dot products
4. **Multi-modal Retrieval**: Cross-modal similarity learning

### 🎯 Next Steps:

- **Load Balancing**: Ensure component utilization
- **Efficient Algorithms**: Fast top-K retrieval
- **GPU Optimization**: Hardware-friendly implementations
- **Real-world Evaluation**: Large-scale datasets