# Retrieval with Learned Similarities - Implementation and Study

## 📄 Paper Information
- **Title**: Retrieval with Learned Similarities
- **Authors**: Bailu Ding (Microsoft Research), Jiaqi Zhai (Meta)
- **Conference**: WWW 2025
- **arXiv**: 2407.15462v4
- **GitHub**: https://github.com/bailuding/rails

## 🎯 Paper Summary

The paper addresses **efficient retrieval with learned similarity functions**, an important problem in modern recommendation, search, and NLP systems.

### 🔑 Core problem:
- State-of-the-art retrieval algorithms have moved from simple dot products to more complex **learned similarities**
- Learned similarities include multiple query embeddings, neural networks, item ID decoding via beam search, and hybrid approaches
- **Efficient solutions are lacking** for retrieval with these learned similarities

### 💡 Solution - Mixture-of-Logits (MoL):
1. **Universal Approximator**: MoL can represent any similarity function
2. **Efficient Retrieval**: Algorithms that retrieve the approximate top-k with tight error bounds
3. **RAILS Framework**: Retrieval with Learned Similarities on GPUs

### 📊 Results:
- **Performance**: 20-30% improvement in Hit Rate@50-400 on corpora of hundreds of millions to billions of items
- **Speed**: Approximate top-k retrieval is up to 66x faster than baselines while keeping recall above 99% relative to the exact algorithms
- **Applicability**: Works well for recommendation systems and question answering


## 🛠️ Environment Setup

In [None]:
# Core dependencies
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Union
import math
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
from abc import ABC, abstractmethod
import time
import warnings
warnings.filterwarnings('ignore')

# LangChain components for RAG integration
try:
    from langchain_core.embeddings import Embeddings
    from langchain_core.vectorstores import VectorStore
    from langchain_community.vectorstores import FAISS
    from langchain_openai import OpenAIEmbeddings
    LANGCHAIN_AVAILABLE = True
    print("✅ LangChain components loaded successfully")
except ImportError:
    LANGCHAIN_AVAILABLE = False
    print("⚠️ LangChain not available - using pure PyTorch implementation")

# DeepEval for evaluation
try:
    import deepeval
    from deepeval.metrics import FaithfulnessMetric, AnswerRelevancyMetric
    DEEPEVAL_AVAILABLE = True
    print("✅ DeepEval loaded for evaluation")
except ImportError:
    DEEPEVAL_AVAILABLE = False
    print("⚠️ DeepEval not available - using custom metrics")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔧 Using device: {device}")

## 📊 Preparing Mock Data

Generate synthetic data to demonstrate the concepts in the paper:

In [None]:
@dataclass
class DatasetConfig:
    """Configuration for synthetic dataset generation"""
    num_queries: int = 1000
    num_items: int = 10000
    embedding_dim: int = 128
    num_categories: int = 10
    noise_level: float = 0.1

def generate_synthetic_data(config: DatasetConfig) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate synthetic query-item data for retrieval experiments
    
    Returns:
        queries: [num_queries, embedding_dim]
        items: [num_items, embedding_dim] 
        query_categories: [num_queries]
        item_categories: [num_items]
    """
    torch.manual_seed(42)
    
    # Generate category centers
    category_centers = torch.randn(config.num_categories, config.embedding_dim)
    category_centers = F.normalize(category_centers, dim=1)
    
    # Generate queries
    query_categories = torch.randint(0, config.num_categories, (config.num_queries,))
    queries = category_centers[query_categories] + config.noise_level * torch.randn(config.num_queries, config.embedding_dim)
    queries = F.normalize(queries, dim=1)
    
    # Generate items
    item_categories = torch.randint(0, config.num_categories, (config.num_items,))
    items = category_centers[item_categories] + config.noise_level * torch.randn(config.num_items, config.embedding_dim)
    items = F.normalize(items, dim=1)
    
    return queries, items, query_categories, item_categories

# Generate data
config = DatasetConfig()
queries, items, query_cats, item_cats = generate_synthetic_data(config)

print(f"📊 Generated data:")
print(f"   Queries: {queries.shape}")
print(f"   Items: {items.shape}")
print(f"   Categories: {config.num_categories}")

# Move to device
queries = queries.to(device)
items = items.to(device)
query_cats = query_cats.to(device)
item_cats = item_cats.to(device)

## 🧠 Implementing Mixture-of-Logits (MoL)

### 📖 MoL Theory (from the paper, Section 2)

MoL is defined as:

$$\phi(q,x) = \sum_{p=1}^{P} \pi_p(q,x) \langle f_p(q), g_p(x) \rangle$$

where:
- $f_p(q), g_p(x) \in \mathbb{R}^{d_P}$: component-level embeddings
- $\pi_p(q,x) \in [0,1]$: adaptive gating weights, normalized so that $\sum_{p} \pi_p(q,x) = 1$ (the implementation in this notebook uses a softmax)
- $P$: number of components

**Why is MoL effective?**
1. **Universal Approximator**: It can represent a $p(x|q)$ matrix of arbitrary rank
2. **GPU-friendly**: It exploits the high arithmetic intensity of GPUs
3. **Flexible**: It can emulate many different learned similarity functions
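
To make the formula concrete, here is a minimal sketch of the MoL score for all query-item pairs at once. It assumes the component embeddings and the gating logits are already computed. The function name `mol_scores` and the tensor layouts are illustrative choices for this notebook, not the paper's API.

```python
import torch
import torch.nn.functional as F

def mol_scores(f_q: torch.Tensor, g_x: torch.Tensor, gate_logits: torch.Tensor) -> torch.Tensor:
    """
    f_q:         [B_q, P, d]    query-side component embeddings f_p(q)
    g_x:         [B_x, P, d]    item-side component embeddings g_p(x)
    gate_logits: [B_q, B_x, P]  unnormalized gating scores (assumed layout)
    returns phi: [B_q, B_x]
    """
    # Per-component dot products <f_p(q), g_p(x)> for every (q, x) pair
    component_logits = torch.einsum("qpd,xpd->qxp", f_q, g_x)
    # Softmax keeps each weight in [0, 1] and makes them sum to 1 over p
    pi = F.softmax(gate_logits, dim=-1)
    # Weighted sum over the P components
    return (pi * component_logits).sum(dim=-1)

torch.manual_seed(0)
f_q = torch.randn(3, 4, 8)          # 3 queries, P = 4 components, d = 8
g_x = torch.randn(5, 4, 8)          # 5 items
gate_logits = torch.randn(3, 5, 4)
phi = mol_scores(f_q, g_x, gate_logits)
print(phi.shape)  # torch.Size([3, 5])
```

With $P = 1$ the gating weight is always 1, so the score reduces to a plain dot product. That is a quick sanity check for any MoL implementation.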

In [None]:
class MixtureOfLogitsModule(nn.Module):
    """
    Mixture-of-Logits (MoL) implementation as described in the paper
    
    Reference: Section 2 - Mixture of Logits
    Formula: φ(q,x) = Σ π_p(q,x) * <f_p(q), g_p(x)>
    """
    
    def __init__(self, 
                 input_dim: int,
                 num_components: int = 8,
                 component_dim: int = 64,
                 use_outer_product: bool = True,
                 dropout: float = 0.1):
        super().__init__()
        
        self.input_dim = input_dim
        self.num_components = num_components
        self.component_dim = component_dim
        self.use_outer_product = use_outer_product
        
        if use_outer_product:
            # Batched outer product form (Section 2, Equation after (1))
            self.P_q = int(math.sqrt(num_components))
            self.P_x = num_components // self.P_q
            
            # Query-side embeddings
            self.query_embeddings = nn.ModuleList([
                nn.Linear(input_dim, component_dim) for _ in range(self.P_q)
            ])
            
            # Item-side embeddings
            self.item_embeddings = nn.ModuleList([
                nn.Linear(input_dim, component_dim) for _ in range(self.P_x)
            ])
            
            # Gating network for outer product
            self.gating_network = nn.Sequential(
                nn.Linear(input_dim * 2, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, self.P_q * self.P_x),
                nn.Softmax(dim=-1)
            )
        else:
            # Standard MoL form (Equation 1)
            self.query_embeddings = nn.ModuleList([
                nn.Linear(input_dim, component_dim) for _ in range(num_components)
            ])
            
            self.item_embeddings = nn.ModuleList([
                nn.Linear(input_dim, component_dim) for _ in range(num_components)
            ])
            
            # Gating network
            self.gating_network = nn.Sequential(
                nn.Linear(input_dim * 2, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, num_components),
                nn.Softmax(dim=-1)
            )
    
    def forward(self, queries: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        """
        Compute MoL similarity between queries and items
        
        Args:
            queries: [batch_q, input_dim]
            items: [batch_x, input_dim]
            
        Returns:
            similarities: [batch_q, batch_x]
        """
        batch_q, batch_x = queries.size(0), items.size(0)
        
        if self.use_outer_product:
            return self._forward_outer_product(queries, items)
        else:
            return self._forward_standard(queries, items)
    
    def _pairwise_gating(self, queries: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        """Gating weights for every (query, item) pair: [batch_q, batch_x, num_weights]"""
        batch_q, batch_x = queries.size(0), items.size(0)
        # Broadcast queries against items, then concatenate their features.
        # This materializes batch_q * batch_x rows; chunk the items for large corpora.
        pair_features = torch.cat([
            queries.unsqueeze(1).expand(-1, batch_x, -1),
            items.unsqueeze(0).expand(batch_q, -1, -1),
        ], dim=-1)  # [batch_q, batch_x, 2 * input_dim]
        return self.gating_network(pair_features)
    
    def _forward_standard(self, queries: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        """Standard MoL computation (Equation 1), vectorized over all pairs"""
        # Component embeddings, stacked along a component axis
        q_components = torch.stack(
            [F.normalize(emb(queries), dim=-1) for emb in self.query_embeddings], dim=1
        )  # [batch_q, P, component_dim]
        x_components = torch.stack(
            [F.normalize(emb(items), dim=-1) for emb in self.item_embeddings], dim=1
        )  # [batch_x, P, component_dim]
        
        # <f_p(q), g_p(x)> for every pair and every component p
        component_sims = torch.einsum('qpd,xpd->qxp', q_components, x_components)
        weights = self._pairwise_gating(queries, items)  # [batch_q, batch_x, P]
        
        # Weighted sum of component similarities
        return (weights * component_sims).sum(dim=-1)
    
    def _forward_outer_product(self, queries: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        """Batched outer product form: all P_q x P_x component pairs at once"""
        q_components = torch.stack(
            [F.normalize(emb(queries), dim=-1) for emb in self.query_embeddings], dim=1
        )  # [batch_q, P_q, component_dim]
        x_components = torch.stack(
            [F.normalize(emb(items), dim=-1) for emb in self.item_embeddings], dim=1
        )  # [batch_x, P_x, component_dim]
        
        # Dot product of every query component with every item component
        component_sims = torch.einsum('qad,xbd->qxab', q_components, x_components)
        
        # Gating weights over the P_q * P_x combinations, reshaped to match
        weights = self._pairwise_gating(queries, items)
        weights = weights.reshape(queries.size(0), items.size(0), self.P_q, self.P_x)
        
        # Weighted sum over both component axes
        return (weights * component_sims).sum(dim=(-2, -1))

print("🧠 MoL Module implemented successfully")

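## 🎯 Load Balancing Loss (a key part of the paper)

### 📖 Theory (Section 2.2)

The paper proposes a **mutual information-based load balancing loss** to improve conditional computation in MoL:

$$L_{bal} = -I(Z; G)$$

where:
- $Z$: latent representations
- $G$: gating decisions
- $I(Z; G)$: mutual information

**Purpose**: Ensure that the components are used evenly and avoid collapse onto a few of them.

Note that the `LoadBalancingLoss` class defined in this notebook is a simplified proxy: it pushes the batch-averaged gating distribution toward uniform rather than estimating mutual information directly.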
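
For comparison, a mutual-information style estimate can be sketched directly from a batch of gating distributions. The sketch assumes the common estimator "entropy of the batch-averaged gating distribution minus the mean per-example entropy". That choice and the name `mi_load_balancing_loss` are assumptions for illustration, not necessarily the paper's exact formulation.

```python
import math
import torch

def mi_load_balancing_loss(gating_weights: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """gating_weights: [N, P], each row a probability distribution over P components."""
    # Marginal usage of each component across the batch
    marginal = gating_weights.mean(dim=0)
    h_marginal = -(marginal * (marginal + eps).log()).sum()
    # Average entropy of the per-example gating distributions
    h_conditional = -(gating_weights * (gating_weights + eps).log()).sum(dim=-1).mean()
    # Negative mutual information estimate: lower is better
    return -(h_marginal - h_conditional)

balanced = torch.eye(4)                               # each example picks a different component
collapsed = torch.tensor([[1.0, 0.0, 0.0, 0.0]] * 4)  # every example picks component 0
indecisive = torch.full((4, 4), 0.25)                 # every example spreads weight evenly

loss_balanced = mi_load_balancing_loss(balanced)      # close to -log(4)
loss_collapsed = mi_load_balancing_loss(collapsed)    # close to 0
loss_indecisive = mi_load_balancing_loss(indecisive)  # close to 0
print(loss_balanced.item(), loss_collapsed.item(), loss_indecisive.item())
```

Under this estimator the loss is lowest when every component is used equally often across the batch and each individual pair commits to one component. A pure "push the average toward uniform" penalty cannot tell the `balanced` and `indecisive` cases apart.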

In [None]:
class LoadBalancingLoss(nn.Module):
    """
    Mutual Information-based Load Balancing Loss
    
    Reference: Section 2.2 - Load balancing loss to improve conditional computations
    """
    
    def __init__(self, num_components: int, lambda_bal: float = 0.01):
        super().__init__()
        self.num_components = num_components
        self.lambda_bal = lambda_bal
    
    def forward(self, gating_weights: torch.Tensor) -> torch.Tensor:
        """
        Compute load balancing loss
        
        Args:
            gating_weights: [batch_size, num_components] or [batch_size, P_q * P_x]
            
        Returns:
            loss: scalar tensor
        """
        batch_size = gating_weights.size(0)
        
        if gating_weights.dim() > 2:
            gating_weights = gating_weights.view(batch_size, -1)
        
        # Compute component usage statistics
        component_usage = gating_weights.mean(dim=0)  # [num_components]
        
        # Encourage uniform distribution (entropy maximization)
        uniform_target = torch.ones_like(component_usage) / self.num_components
        
        # KL divergence between the uniform target and the observed usage.
        # F.kl_div expects log-probabilities as input, so add eps before taking the log.
        kl_loss = F.kl_div(
            (component_usage + 1e-8).log(), 
            uniform_target, 
            reduction='sum'
        )
        
        # Entropy of the average usage; subtracting it below rewards balanced usage
        entropy = -torch.sum(component_usage * torch.log(component_usage + 1e-8))
        
        return self.lambda_bal * (kl_loss - entropy)

print("⚖️ Load Balancing Loss implemented")

## 🚀 RAILS Framework - Retrieval with Learned Similarities

### 📖 Top-K Retrieval Algorithms (Section 3)

The paper proposes two kinds of algorithms:
1. **Exact Top-K**: Guarantees exact results
2. **Approximate Top-K**: Faster, with error bounds

**Key Insight**: Exploit the structure of MoL to make retrieval efficient.
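
One way to exploit that structure is a two-stage search: collect the best items for each component by dot product, then rerank the union of those candidates with the full similarity. The sketch below is loosely modeled on that idea. The function `two_stage_top_k`, its parameters, and the uniform-gating stand-in `uniform_mol` are hypothetical, and the sketch makes no claim about the paper's error bounds.

```python
import torch

def two_stage_top_k(f_q, g_x, score_fn, k, k_per_component):
    """
    f_q:      [P, d]     component embeddings for one query
    g_x:      [N, P, d]  component embeddings for all items
    score_fn: callable mapping candidate indices [C] to full scores [C]
    """
    # Stage 1: best items by dot product within each component
    component_scores = torch.einsum("pd,npd->pn", f_q, g_x)  # [P, N]
    k_pc = min(k_per_component, g_x.size(0))
    _, idx = component_scores.topk(k_pc, dim=1)               # [P, k_pc]
    candidates = torch.unique(idx.flatten())                   # union of candidates
    # Stage 2: rerank the union with the full similarity
    full = score_fn(candidates)
    top = full.topk(min(k, candidates.numel()))
    return top.values, candidates[top.indices]

torch.manual_seed(0)
P, d, N = 4, 16, 200
f_q = torch.randn(P, d)
g_x = torch.randn(N, P, d)

def uniform_mol(candidates):
    # Hypothetical stand-in for MoL: equal gating weight on every component
    return torch.einsum("pd,cpd->cp", f_q, g_x[candidates]).mean(dim=-1)

scores, indices = two_stage_top_k(f_q, g_x, uniform_mol, k=10, k_per_component=50)
print(scores.shape, indices.shape)
```

Raising `k_per_component` trades speed for recall. When it equals the corpus size, stage 1 keeps every item and the result matches exact top-k.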

In [None]:
class RAILSRetriever:
    """
    Retrieval with Learned Similarities (RAILS) Framework
    
    Reference: Section 3 - Efficient retrieval techniques with MoL
    """
    
    def __init__(self, mol_model: MixtureOfLogitsModule):
        self.mol_model = mol_model
        self.item_embeddings = None
        self.item_ids = None
    
    def index_items(self, items: torch.Tensor, item_ids: Optional[List] = None):
        """
        Index items for efficient retrieval
        
        Args:
            items: [num_items, input_dim]
            item_ids: Optional list of item identifiers
        """
        self.item_embeddings = items
        self.item_ids = item_ids or list(range(len(items)))
        
        print(f"📚 Indexed {len(items)} items")
    
    def exact_top_k_retrieval(self, 
                             queries: torch.Tensor, 
                             k: int = 10) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Exact top-k retrieval using MoL
        
        Args:
            queries: [num_queries, input_dim]
            k: number of top items to retrieve
            
        Returns:
            scores: [num_queries, k] - similarity scores
            indices: [num_queries, k] - item indices
        """
        if self.item_embeddings is None:
            raise ValueError("Items not indexed. Call index_items() first.")
        
        with torch.no_grad():
            # Compute similarities between all queries and items
            similarities = self.mol_model(queries, self.item_embeddings)
            
            # Get top-k for each query
            top_scores, top_indices = torch.topk(similarities, k, dim=1, largest=True)
        
        return top_scores, top_indices
    
    def approximate_top_k_retrieval(self, 
                                   queries: torch.Tensor, 
                                   k: int = 10,
                                   candidate_ratio: float = 0.1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Approximate top-k retrieval with error bounds
        
        Args:
            queries: [num_queries, input_dim]
            k: number of top items to retrieve
            candidate_ratio: ratio of items to consider as candidates
            
        Returns:
            scores: [num_queries, k] - similarity scores
            indices: [num_queries, k] - item indices
        """
        if self.item_embeddings is None:
            raise ValueError("Items not indexed. Call index_items() first.")
        
        num_items = len(self.item_embeddings)
        # Never request more candidates than there are items
        num_candidates = min(num_items, max(k * 2, int(num_items * candidate_ratio)))
        k = min(k, num_candidates)
        
        with torch.no_grad():
            # Step 1: Fast candidate selection using a single component.
            # This is a heuristic: the first component only approximates the full MoL score.
            if hasattr(self.mol_model, 'query_embeddings'):
                q_emb = self.mol_model.query_embeddings[0](queries)
                i_emb = self.mol_model.item_embeddings[0](self.item_embeddings)
                
                q_emb = F.normalize(q_emb, dim=-1)
                i_emb = F.normalize(i_emb, dim=-1)
                
                # Fast similarity computation: [num_queries, num_items]
                fast_sim = torch.mm(q_emb, i_emb.t())
                
                # Select candidates
                _, candidate_indices = torch.topk(fast_sim, num_candidates, dim=1)
            else:
                # Fallback: random sampling, on the same device as the queries
                candidate_indices = torch.randint(
                    0, num_items, (len(queries), num_candidates), device=queries.device
                )
            
            # Step 2: Accurate re-ranking using full MoL
            top_scores_list = []
            top_indices_list = []
            
            for i, query in enumerate(queries):
                candidates = self.item_embeddings[candidate_indices[i]]
                
                # Score one query against its candidates: [1, num_candidates] -> [num_candidates]
                accurate_sim = self.mol_model(query.unsqueeze(0), candidates).squeeze(0)
                
                # Get top-k from candidates
                top_k_scores, top_k_idx = torch.topk(accurate_sim, k)
                
                # Map back to original item indices
                top_scores_list.append(top_k_scores)
                top_indices_list.append(candidate_indices[i][top_k_idx])
        
        # Every query yields exactly k results, so the lists stack without padding
        return torch.stack(top_scores_list), torch.stack(top_indices_list)

print("🚀 RAILS Retriever implemented")

## 🎓 Training Loop with Load Balancing

In [None]:
def create_relevance_labels(query_cats: torch.Tensor, item_cats: torch.Tensor) -> torch.Tensor:
    """
    Create relevance labels based on category matching
    
    Args:
        query_cats: [num_queries] - query categories
        item_cats: [num_items] - item categories
        
    Returns:
        labels: [num_queries, num_items] - binary relevance
    """
    # Broadcast comparison: labels[i, j] = 1 when query i and item j share a category
    labels = (query_cats.unsqueeze(1) == item_cats.unsqueeze(0)).float()
    
    return labels

def train_mol_model(mol_model: MixtureOfLogitsModule,
                   queries: torch.Tensor,
                   items: torch.Tensor, 
                   query_cats: torch.Tensor,
                   item_cats: torch.Tensor,
                   num_epochs: int = 50,
                   batch_size: int = 32,
                   lr: float = 0.001) -> Dict:
    """
    Train MoL model with load balancing loss
    """
    optimizer = torch.optim.AdamW(mol_model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    load_balancing_loss = LoadBalancingLoss(
        mol_model.num_components if not mol_model.use_outer_product 
        else mol_model.P_q * mol_model.P_x
    )
    
    # Create training data
    relevance_labels = create_relevance_labels(query_cats, item_cats).to(device)
    
    train_losses = []
    bal_losses = []
    
    mol_model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_bal_loss = 0.0
        num_batches = 0
        
        # Sample training pairs
        for _ in range(max(1, len(queries) // batch_size)):
            # Sample query batch
            q_indices = torch.randint(0, len(queries), (batch_size,))
            i_indices = torch.randint(0, len(items), (batch_size,))
            
            query_batch = queries[q_indices]
            item_batch = items[i_indices]
            labels_batch = relevance_labels[q_indices][:, i_indices]
            
            optimizer.zero_grad()
            
            # Forward pass
            similarities = mol_model(query_batch, item_batch)
            
            # Main loss (binary cross entropy)
            main_loss = F.binary_cross_entropy_with_logits(
                similarities.diag(), labels_batch.diag()
            )
            
            # Gating weights for the sampled (query, item) pairs. They are computed with
            # gradients enabled so the balancing term actually trains the gating network.
            pair_features = torch.cat([query_batch, item_batch], dim=-1)
            gating_weights = mol_model.gating_network(pair_features)  # [batch_size, num_weights]
            
            bal_loss = load_balancing_loss(gating_weights)
            
            total_loss = main_loss + bal_loss
            
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(mol_model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += main_loss.item()
            epoch_bal_loss += bal_loss.item()
            num_batches += 1
        
        scheduler.step()
        
        avg_loss = epoch_loss / num_batches
        avg_bal_loss = epoch_bal_loss / num_batches
        
        train_losses.append(avg_loss)
        bal_losses.append(avg_bal_loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch:3d}: Loss = {avg_loss:.4f}, Bal Loss = {avg_bal_loss:.4f}")
    
    return {
        'train_losses': train_losses,
        'bal_losses': bal_losses
    }

print("🎓 Training functions defined")

## 🧪 Experiments and Comparison

Compare MoL with a baseline approach:

In [None]:
# Initialize models
print("🏗️ Initializing models...")

# MoL model
mol_model = MixtureOfLogitsModule(
    input_dim=config.embedding_dim,
    num_components=8,
    component_dim=32,
    use_outer_product=True
).to(device)

# Simple dot product baseline
class DotProductModel(nn.Module):
    def __init__(self, input_dim: int, embed_dim: int = 128):
        super().__init__()
        self.query_proj = nn.Linear(input_dim, embed_dim)
        self.item_proj = nn.Linear(input_dim, embed_dim)
    
    def forward(self, queries: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        q_emb = F.normalize(self.query_proj(queries), dim=-1)
        i_emb = F.normalize(self.item_proj(items), dim=-1)
        return torch.mm(q_emb, i_emb.t())

dot_model = DotProductModel(config.embedding_dim).to(device)

print(f"📊 Model parameters:")
print(f"   MoL: {sum(p.numel() for p in mol_model.parameters()):,}")
print(f"   Dot Product: {sum(p.numel() for p in dot_model.parameters()):,}")

In [None]:
# Train MoL model
print("🎓 Training MoL model...")
training_history = train_mol_model(
    mol_model, queries, items, query_cats, item_cats,
    num_epochs=30, batch_size=16, lr=0.001
)

# Train baseline
print("\n🎓 Training Dot Product baseline...")
optimizer_dot = torch.optim.AdamW(dot_model.parameters(), lr=0.001)
relevance_labels = create_relevance_labels(query_cats, item_cats).to(device)

dot_model.train()
for epoch in range(20):
    total_loss = 0
    for _ in range(10):
        q_idx = torch.randint(0, len(queries), (16,))
        i_idx = torch.randint(0, len(items), (16,))
        
        similarities = dot_model(queries[q_idx], items[i_idx])
        labels = relevance_labels[q_idx][:, i_idx]
        
        loss = F.binary_cross_entropy_with_logits(similarities.diag(), labels.diag())
        
        optimizer_dot.zero_grad()
        loss.backward()
        optimizer_dot.step()
        
        total_loss += loss.item()
    
    if epoch % 5 == 0:
        print(f"Epoch {epoch:2d}: Loss = {total_loss/10:.4f}")

print("✅ Training completed")

## 📊 Evaluation with Custom Metrics

Because DeepEval may not be available, we implement custom metrics following the paper.
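
As a small worked example of the two metrics, the sketch below scores a single query in plain Python. It assumes `ranked` lists the binary relevance of every item in the order the model ranked them. The helper names are illustrative and separate from the notebook's own metric functions.

```python
import math

def hit_rate_at_k(ranked_relevance, k):
    """1.0 if any of the first k ranked items is relevant, else 0.0."""
    return 1.0 if any(ranked_relevance[:k]) else 0.0

def ndcg_at_k(ranked_relevance, k):
    """DCG of the first k ranked items divided by the DCG of an ideal ranking."""
    dcg = sum(rel / math.log2(rank + 2)
              for rank, rel in enumerate(ranked_relevance[:k]))
    ideal = sorted(ranked_relevance, reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0

# Relevant items sit at ranks 2 and 4 (1-based)
ranked = [0, 1, 0, 1, 0]
print(hit_rate_at_k(ranked, 1))         # 0.0
print(hit_rate_at_k(ranked, 2))         # 1.0
print(round(ndcg_at_k(ranked, 5), 4))   # 0.6509
```

Here DCG is $1/\log_2 3 + 1/\log_2 5 \approx 1.0616$ and the ideal DCG is $1 + 1/\log_2 3 \approx 1.6309$, which gives NDCG@5 of about 0.65.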

In [None]:
def compute_hit_rate_at_k(predictions: torch.Tensor, 
                         labels: torch.Tensor, 
                         k_values: List[int] = [1, 5, 10, 50]) -> Dict[str, float]:
    """
    Compute Hit Rate@K as used in the paper
    
    Args:
        predictions: [num_queries, num_items] - similarity scores
        labels: [num_queries, num_items] - binary relevance
        k_values: list of k values to evaluate
        
    Returns:
        hit_rates: dict mapping k to hit rate
    """
    hit_rates = {}
    
    for k in k_values:
        # Get top-k predictions for each query
        _, top_k_indices = torch.topk(predictions, k, dim=1)
        
        # Check if any top-k item is relevant
        hits = 0
        for i in range(len(predictions)):
            relevant_items = labels[i].nonzero().squeeze(-1)
            if len(relevant_items) > 0:
                top_k_items = set(top_k_indices[i].cpu().numpy())
                relevant_set = set(relevant_items.cpu().numpy())
                if top_k_items.intersection(relevant_set):
                    hits += 1
        
        hit_rates[f'HR@{k}'] = hits / len(predictions)
    
    return hit_rates

def compute_ndcg_at_k(predictions: torch.Tensor, 
                     labels: torch.Tensor, 
                     k_values: List[int] = [5, 10]) -> Dict[str, float]:
    """
    Compute NDCG@K
    """
    ndcg_scores = {}
    
    for k in k_values:
        _, top_k_indices = torch.topk(predictions, k, dim=1)
        
        total_ndcg = 0.0
        for i in range(len(predictions)):
            # Get relevance scores for top-k items
            relevance_scores = labels[i][top_k_indices[i]].cpu().numpy()
            
            # Compute DCG
            dcg = 0.0
            for j, rel in enumerate(relevance_scores):
                dcg += rel / np.log2(j + 2)  # j+2 because log2(1) = 0
            
            # Compute IDCG (perfect ranking)
            ideal_relevance = sorted(labels[i].cpu().numpy(), reverse=True)[:k]
            idcg = 0.0
            for j, rel in enumerate(ideal_relevance):
                idcg += rel / np.log2(j + 2)
            
            # NDCG
            if idcg > 0:
                total_ndcg += dcg / idcg
        
        ndcg_scores[f'NDCG@{k}'] = total_ndcg / len(predictions)
    
    return ndcg_scores

print("📊 Custom evaluation metrics defined")

In [None]:
# Evaluation
print("🧪 Evaluating models...")

mol_model.eval()
dot_model.eval()

# Use subset for evaluation (computational efficiency)
eval_queries = queries[:100]
eval_items = items[:1000]
eval_query_cats = query_cats[:100]
eval_item_cats = item_cats[:1000]

eval_labels = create_relevance_labels(eval_query_cats, eval_item_cats).to(device)

with torch.no_grad():
    # MoL predictions
    print("   Computing MoL predictions...")
    mol_predictions = mol_model(eval_queries, eval_items)
    
    # Dot product predictions
    print("   Computing Dot Product predictions...")
    dot_predictions = dot_model(eval_queries, eval_items)

# Compute metrics
print("\n📊 Results:")
print("\n🧠 MoL Performance:")
mol_hr = compute_hit_rate_at_k(mol_predictions, eval_labels)
mol_ndcg = compute_ndcg_at_k(mol_predictions, eval_labels)
for metric, value in {**mol_hr, **mol_ndcg}.items():
    print(f"   {metric}: {value:.4f}")

print("\n⚫ Dot Product Performance:")
dot_hr = compute_hit_rate_at_k(dot_predictions, eval_labels)
dot_ndcg = compute_ndcg_at_k(dot_predictions, eval_labels)
for metric, value in {**dot_hr, **dot_ndcg}.items():
    print(f"   {metric}: {value:.4f}")

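# Improvement analysis (relative change in Hit Rate over the dot product baseline)
print("\n📈 Improvements:")
for metric in mol_hr.keys():
    if dot_hr[metric] > 0:
        improvement = (mol_hr[metric] - dot_hr[metric]) / dot_hr[metric] * 100
        print(f"   {metric}: {improvement:+.1f}%")
    else:
        # Avoid division by zero when the baseline scores no hits
        print(f"   {metric}: n/a (baseline is 0)")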

## ⚡ Performance & Latency Analysis

This section benchmarks exact against approximate retrieval, in the spirit of the paper's latency analysis (Section 4.4).
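
Single wall-clock measurements are noisy, and on a GPU they can be misleading because kernels run asynchronously. A minimal timing helper, sketched below under the assumption that a warm-up run plus a median over a few repeats is good enough here, takes an optional `sync` callback for this purpose. The helper `time_call` is hypothetical; on CUDA one would pass `sync=torch.cuda.synchronize`.

```python
import time

def time_call(fn, *args, repeats=5, warmup=1, sync=None, **kwargs):
    """Return (result, median seconds) of calling fn(*args, **kwargs)."""
    # Warm-up runs absorb one-time costs such as lazy initialization
    for _ in range(warmup):
        fn(*args, **kwargs)
    if sync is not None:
        sync()
    timings = []
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        if sync is not None:
            sync()  # wait for asynchronous work before stopping the clock
        timings.append(time.perf_counter() - start)
    timings.sort()
    return result, timings[len(timings) // 2]

result, median_s = time_call(sum, range(100_000), repeats=3)
print(result, f"{median_s:.6f}s")
```

The same helper could wrap `exact_top_k_retrieval` and `approximate_top_k_retrieval`, so both paths are measured under identical conditions.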

In [None]:
# Performance benchmarking
print("⚡ Performance Benchmarking...")

# Create RAILS retriever
rails_retriever = RAILSRetriever(mol_model)
rails_retriever.index_items(items)

# Test queries
test_queries = queries[:10]

def sync_device():
    """Wait for queued CUDA kernels so wall-clock timings are meaningful."""
    if device.type == 'cuda':
        torch.cuda.synchronize()

# Benchmark exact retrieval
print("\n🎯 Exact Top-K Retrieval:")
sync_device()
start_time = time.perf_counter()
exact_scores, exact_indices = rails_retriever.exact_top_k_retrieval(test_queries, k=10)
sync_device()
exact_time = time.perf_counter() - start_time
print(f"   Time: {exact_time:.4f}s")
print(f"   Throughput: {len(test_queries)/exact_time:.1f} queries/sec")

# Benchmark approximate retrieval
print("\n🚀 Approximate Top-K Retrieval:")
sync_device()
start_time = time.perf_counter()
approx_scores, approx_indices = rails_retriever.approximate_top_k_retrieval(
    test_queries, k=10, candidate_ratio=0.1
)
sync_device()
approx_time = time.perf_counter() - start_time
print(f"   Time: {approx_time:.4f}s")
print(f"   Throughput: {len(test_queries)/approx_time:.1f} queries/sec")
print(f"   Speedup: {exact_time/approx_time:.1f}x")

# Compute recall between exact and approximate
total_recall = 0.0
for i in range(len(test_queries)):
    exact_set = set(exact_indices[i][:5].cpu().numpy())  # Top-5
    approx_set = set(approx_indices[i][:5].cpu().numpy())
    recall = len(exact_set.intersection(approx_set)) / len(exact_set)
    total_recall += recall

avg_recall = total_recall / len(test_queries)
print(f"   Recall@5: {avg_recall:.3f}")

# Memory usage analysis
print(f"\n💾 Memory Analysis:")
print(f"   MoL model size: {sum(p.numel() * 4 for p in mol_model.parameters()) / 1024 / 1024:.1f} MB")
print(f"   Item embeddings: {items.numel() * 4 / 1024 / 1024:.1f} MB")

## 📈 Visualization & Analysis

In [None]:
# Plotting results
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Training loss
axes[0, 0].plot(training_history['train_losses'], label='Main Loss', color='blue')
axes[0, 0].plot(training_history['bal_losses'], label='Load Balancing Loss', color='red')
axes[0, 0].set_title('Training Losses')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Performance comparison
metrics = ['HR@1', 'HR@5', 'HR@10', 'NDCG@5', 'NDCG@10']
mol_values = [mol_hr.get(m, mol_ndcg.get(m, 0)) for m in metrics]
dot_values = [dot_hr.get(m, dot_ndcg.get(m, 0)) for m in metrics]

x = np.arange(len(metrics))
width = 0.35

axes[0, 1].bar(x - width/2, mol_values, width, label='MoL', color='green', alpha=0.7)
axes[0, 1].bar(x + width/2, dot_values, width, label='Dot Product', color='orange', alpha=0.7)
axes[0, 1].set_title('Performance Comparison')
axes[0, 1].set_ylabel('Score')
axes[0, 1].set_xticks(x)
axes[0, 1].set_xticklabels(metrics, rotation=45)
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Similarity distribution analysis
with torch.no_grad():
    sample_queries = eval_queries[:5]
    sample_items = eval_items[:100] 
    
    mol_sims = mol_model(sample_queries, sample_items)
    dot_sims = dot_model(sample_queries, sample_items)
    
    axes[1, 0].hist(mol_sims.cpu().flatten().numpy(), bins=50, alpha=0.7, label='MoL', color='green')
    axes[1, 0].hist(dot_sims.cpu().flatten().numpy(), bins=50, alpha=0.7, label='Dot Product', color='orange')
    axes[1, 0].set_title('Similarity Score Distribution')
    axes[1, 0].set_xlabel('Similarity Score')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

# Component utilization (if we can extract gating weights)
try:
    with torch.no_grad():
        sample_features = torch.cat([sample_queries[0], sample_items[0]], dim=0)
        gating_weights = mol_model.gating_network(sample_features.unsqueeze(0)).squeeze(0)
        
        if mol_model.use_outer_product:
            gating_weights = gating_weights.view(mol_model.P_q, mol_model.P_x)
            im = axes[1, 1].imshow(gating_weights.cpu().numpy(), cmap='viridis')
            axes[1, 1].set_title('Component Gating Weights (Outer Product)')
            axes[1, 1].set_xlabel('Item Components')
            axes[1, 1].set_ylabel('Query Components')
            plt.colorbar(im, ax=axes[1, 1])
        else:
            axes[1, 1].bar(range(len(gating_weights)), gating_weights.cpu().numpy())
            axes[1, 1].set_title('Component Utilization')
            axes[1, 1].set_xlabel('Component Index')
            axes[1, 1].set_ylabel('Weight')
            axes[1, 1].grid(True, alpha=0.3)
except Exception:
    axes[1, 1].text(0.5, 0.5, 'Component weights\nnot available', ha='center', va='center', transform=axes[1, 1].transAxes)
    axes[1, 1].set_title('Component Analysis')

plt.tight_layout()
plt.show()

print("📊 Visualization completed")

## 🎯 Template for Your Own Research

### 🔧 Customization Points:

1. **Modify Architecture**:

In [None]:
# Example: Custom MoL variant
class CustomMoLVariant(MixtureOfLogitsModule):
    def __init__(self, *args, attention_mechanism=True, **kwargs):
        super().__init__(*args, **kwargs)
        
        if attention_mechanism:
            self.attention = nn.MultiheadAttention(
                embed_dim=self.component_dim,
                num_heads=4,
                batch_first=True
            )
    
    def forward(self, queries, items):
        # Add your custom logic here
        return super().forward(queries, items)

print("🎯 Custom variant template defined")

2. **Experiment with Different Datasets**:

In [None]:
def load_custom_dataset(dataset_path: str):
    """
    Template for loading custom datasets
    
    Replace this with your data loading logic:
    - MovieLens for recommendation
    - MS MARCO for QA
    - Your domain-specific data
    """
    # TODO: Implement your data loading
    pass

def create_domain_specific_features(raw_data):
    """
    Create domain-specific features for your use case
    """
    # TODO: Feature engineering
    pass

print("📁 Dataset templates defined")

3. **Integration with LangChain**:

In [None]:
if LANGCHAIN_AVAILABLE:
    class MoLEmbeddings(Embeddings):
        """
        LangChain-compatible MoL embeddings
        """
        def __init__(self, mol_model: MixtureOfLogitsModule):
            self.mol_model = mol_model
        
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            # TODO: Implement document embedding using MoL
            pass
        
        def embed_query(self, text: str) -> List[float]:
            # TODO: Implement query embedding using MoL
            pass
    
    print("🦜 LangChain integration template ready")
else:
    print("⚠️ LangChain not available - skipping integration")

## 🏁 Conclusion and Future Directions

### 📋 Implementation Summary:

✅ **Implemented**:
- Mixture-of-Logits (MoL) architecture with the outer product form
- Load balancing loss to improve component utilization
- RAILS framework with exact and approximate top-k retrieval
- Evaluation metrics (Hit Rate, NDCG)
- Performance benchmarking and analysis

### 🎯 Key Insights from the Paper:

1. **MoL as Universal Approximator**: It can represent any similarity function
2. **GPU Efficiency**: It exploits the high arithmetic intensity of modern accelerators
3. **Load Balancing**: Critical for avoiding component collapse
4. **Approximate Retrieval**: Up to 66x speedup with >99% recall

### 🚀 Future Directions:

1. **Scale to Real Datasets**: Test on MovieLens and MS MARCO
2. **Advanced Architectures**: Attention mechanisms, transformer-based gating
3. **Production Optimization**: CUDA kernels, quantization
4. **Multi-modal Extensions**: Text, image, and audio embeddings

### 📚 References:

- **Paper**: Ding, B., & Zhai, J. (2025). Retrieval with Learned Similarities. WWW 2025.
- **Code**: https://github.com/bailuding/rails
- **Implementation**: Complete in this notebook, with an educational focus
