# Lecture 2: Hash Embeddings - Interactive Notebook

> **Companion to**: 02-hash-embeddings.md
> **Run time**: ~5 minutes

This notebook implements the hash embedding system from X's recommendation algorithm.

**Key concepts covered:**
- Multi-hash collision avoidance
- Block user/history/candidate reduction
- RecsysBatch/RecsysEmbeddings data containers
- Padding masks

---

## 1. Imports + Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Configuration (matches X's HashConfig from recsys_model.py)
EMB_SIZE = 64
NUM_USER_HASHES = 2
NUM_ITEM_HASHES = 2
NUM_AUTHOR_HASHES = 2
VOCAB_SIZE = 100000  # Rows per hash table (the hash range); index 0 is reserved for padding

## 2. MultiHashEmbedding Module

This module simulates the hash-table lookup. In production, X uses actual hash functions to map entity IDs to embedding table indices; here we skip the hashing step and feed the tables randomly drawn integer indices.
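
For intuition about the hashing step that the random indices stand in for, here is a minimal sketch of mapping one entity ID to one index per table. The helper name `multi_hash_ids`, the choice of BLAKE2b, and the seed-prefix scheme are illustrative assumptions, not X's production hash functions; the only property carried over from this notebook is that index 0 stays reserved for padding.

```python
import hashlib
from typing import List


def multi_hash_ids(entity_id: int, num_hashes: int, vocab_size: int) -> List[int]:
    """Hypothetical helper: one table index per hash function for a single entity."""
    indices = []
    for seed in range(num_hashes):
        # Prefixing the seed gives each table its own hash of the same ID.
        digest = hashlib.blake2b(f"{seed}:{entity_id}".encode(), digest_size=8).digest()
        value = int.from_bytes(digest, "little")
        # Map into [1, vocab_size - 1] so a real entity never lands on the padding index 0.
        indices.append(1 + value % (vocab_size - 1))
    return indices


# The same ID always maps to the same row in each table.
print(multi_hash_ids(1234567890123, num_hashes=2, vocab_size=100_000))
```

A batch of such lists can be converted to the `[B, num_hashes]` index tensor with `torch.tensor(rows, dtype=torch.long)`.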

In [None]:
class MultiHashEmbedding(nn.Module):
    """
    Multi-hash embedding table that reduces collisions.
    
    Instead of one embedding table, we use NUM_HASHES separate tables.
    Each entity gets embeddings from each table, then we combine them.
    
    Assuming independent, uniform hashes, the probability that two entities
    share a row in every table drops from 1/vocab_size to 1/vocab_size^NUM_HASHES.
    """
    
    def __init__(self, vocab_size: int, emb_size: int, num_hashes: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.num_hashes = num_hashes
        
        # Create separate embedding tables for each hash function
        # Shape: [num_hashes, vocab_size, emb_size]
        self.embeddings = nn.Parameter(
            torch.randn(num_hashes, vocab_size, emb_size) * 0.1
        )
    
    def forward(self, hash_indices: torch.Tensor) -> torch.Tensor:
        """
        Look up embeddings for given hash indices.
        
        Args:
            hash_indices: [B, num_hashes] tensor of hash indices
        
        Returns:
            embeddings: [B, num_hashes, emb_size] tensor of looked-up embeddings
        """
        B, num_hashes = hash_indices.shape
        assert num_hashes == self.num_hashes, \
            f"Expected {self.num_hashes} hash indices, got {num_hashes}"
        
        # Look up embeddings from each hash table
        # We iterate over hash functions and gather the corresponding embeddings
        all_embeddings = []
        for h in range(self.num_hashes):
            # Get indices for this hash function
            indices = hash_indices[:, h]  # [B]
            # Look up: [B, emb_size]
            emb = self.embeddings[h, indices, :]
            all_embeddings.append(emb)
        
        # Stack: [B, num_hashes, emb_size]
        return torch.stack(all_embeddings, dim=1)

# Test the module
print("Testing MultiHashEmbedding...")
hash_emb = MultiHashEmbedding(VOCAB_SIZE, EMB_SIZE, NUM_USER_HASHES).to(device)

# Create dummy hash indices (e.g., user hashes)
user_hashes = torch.randint(1, VOCAB_SIZE, (4, NUM_USER_HASHES)).to(device)
print(f"User hashes shape: {user_hashes.shape}")  # [4, 2]

# Look up embeddings
user_embeddings = hash_emb(user_hashes)
print(f"User embeddings shape: {user_embeddings.shape}")  # [4, 2, 64]
print(f"✓ MultiHashEmbedding works!\n")
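
The per-table loop in `forward` is easy to read, but the same lookup can also be written as a single advanced-indexing call. The sketch below uses small stand-in tensors (not the module above) and checks that both forms return identical results; whether the vectorized form is faster in practice depends on the table count and device, so treat it as an option rather than a recommendation.

```python
import torch

torch.manual_seed(0)
num_hashes, vocab_size, emb_size, batch = 2, 1000, 8, 4

tables = torch.randn(num_hashes, vocab_size, emb_size)            # [H, V, D]
hash_indices = torch.randint(1, vocab_size, (batch, num_hashes))  # [B, H]

# Loop form, mirroring MultiHashEmbedding.forward.
looped = torch.stack(
    [tables[h, hash_indices[:, h]] for h in range(num_hashes)], dim=1
)

# Vectorized form: table_ids [H] broadcasts against hash_indices [B, H],
# so element (b, h) selects row hash_indices[b, h] of table h.
table_ids = torch.arange(num_hashes)
vectorized = tables[table_ids, hash_indices]                      # [B, H, D]

assert torch.equal(looped, vectorized)
print(vectorized.shape)
```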

## 3. BlockUserReduce Module

Flattens and projects multiple user hash embeddings into a single user representation.

**From recsys_model.py:79-119**

In [None]:
class BlockUserReduce(nn.Module):
    """
    Combine multiple user hash embeddings into a single user representation.
    
    Process:
    1. Flatten: [B, num_hashes, D] -> [B, num_hashes * D]
    2. Project: [B, num_hashes * D] -> [B, 1, D] via linear projection
    3. Create padding mask based on first hash (hash value 0 = padding)
    """
    
    def __init__(self, num_hashes: int, emb_size: int):
        super().__init__()
        self.num_hashes = num_hashes
        self.emb_size = emb_size
        
        # Projection matrix: [num_hashes * D, D]
        self.projection = nn.Linear(num_hashes * emb_size, emb_size, bias=False)
    
    def forward(self, 
                user_hashes: torch.Tensor,
                user_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            user_hashes: [B, num_hashes] hash indices (0 = padding)
            user_embeddings: [B, num_hashes, D] looked-up embeddings
        
        Returns:
            user_embedding: [B, 1, D] combined user representation
            user_padding_mask: [B, 1] boolean mask (True = valid)
        """
        B = user_embeddings.shape[0]
        
        # Step 1: Flatten
        # [B, num_hashes, D] -> [B, num_hashes * D]
        user_flat = user_embeddings.reshape(B, -1)
        
        # Step 2: Project
        # [B, num_hashes * D] -> [B, D]
        user_proj = self.projection(user_flat)
        
        # [B, D] -> [B, 1, D]
        user_embedding = user_proj.unsqueeze(1)
        
        # Step 3: Create padding mask
        # Hash 0 is reserved for padding
        user_padding_mask = (user_hashes[:, 0] != 0).unsqueeze(1)  # [B, 1]
        
        return user_embedding, user_padding_mask

# Test the module
print("Testing BlockUserReduce...")
block_user_reduce = BlockUserReduce(NUM_USER_HASHES, EMB_SIZE).to(device)

# Create test data
user_hashes = torch.randint(1, VOCAB_SIZE, (4, NUM_USER_HASHES)).to(device)
user_embeddings = hash_emb(user_hashes)

# Apply block user reduce
user_emb, user_mask = block_user_reduce(user_hashes, user_embeddings)

print(f"Input user hashes shape: {user_hashes.shape}")  # [4, 2]
print(f"Input user embeddings shape: {user_embeddings.shape}")  # [4, 2, 64]
print(f"Output user embedding shape: {user_emb.shape}")  # [4, 1, 64]
print(f"Output user padding mask shape: {user_mask.shape}")  # [4, 1]
print(f"User padding mask: {user_mask.flatten()}")
print(f"✓ BlockUserReduce works!\n")

## 4. BlockHistoryReduce Module

Combines 4 ingredients for each history item into a single embedding:
1. Post embeddings (from hash)
2. Author embeddings (from hash)
3. Product surface embeddings (learned lookup)
4. Action embeddings (multi-hot)

**From recsys_model.py:122-182**

In [None]:
class BlockHistoryReduce(nn.Module):
    """
    Combine history embeddings (post, author, actions, product_surface) into sequence.
    
    For each history position:
    1. Flatten post embeddings: [B, S, num_item_hashes, D] -> [B, S, num_item_hashes * D]
    2. Flatten author embeddings: [B, S, num_author_hashes, D] -> [B, S, num_author_hashes * D]
    3. Concatenate: post + author + actions + product_surface
    4. Project to single embedding
    """
    
    def __init__(self, 
                 num_item_hashes: int,
                 num_author_hashes: int,
                 emb_size: int,
                 num_actions: int = 19):
        super().__init__()
        self.num_item_hashes = num_item_hashes
        self.num_author_hashes = num_author_hashes
        self.emb_size = emb_size
        
        # Input dimension for projection
        input_dim = (
            num_item_hashes * emb_size +  # Post embeddings
            num_author_hashes * emb_size +  # Author embeddings
            emb_size +  # Action embeddings
            emb_size  # Product surface embeddings
        )
        
        self.projection = nn.Linear(input_dim, emb_size, bias=False)
        
        # Product surface embedding table (learned)
        self.product_surface_emb = nn.Embedding(16, emb_size)  # 16 surface types
        
        # Action embedding projection (for multi-hot actions)
        self.action_proj = nn.Linear(num_actions, emb_size, bias=False)
    
    def forward(self,
                history_post_hashes: torch.Tensor,
                history_post_embeddings: torch.Tensor,
                history_author_embeddings: torch.Tensor,
                history_product_surface: torch.Tensor,
                history_actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            history_post_hashes: [B, S, num_item_hashes] hash indices
            history_post_embeddings: [B, S, num_item_hashes, D] post embeddings
            history_author_embeddings: [B, S, num_author_hashes, D] author embeddings
            history_product_surface: [B, S] product surface indices
            history_actions: [B, S, num_actions] multi-hot action vectors
        
        Returns:
            history_embeddings: [B, S, D] combined history sequence
            history_padding_mask: [B, S] boolean mask (True = valid)
        """
        B, S, _, D = history_post_embeddings.shape
        
        # Step 1: Flatten post embeddings
        # [B, S, num_item_hashes, D] -> [B, S, num_item_hashes * D]
        post_flat = history_post_embeddings.reshape(B, S, -1)
        
        # Step 2: Flatten author embeddings
        # [B, S, num_author_hashes, D] -> [B, S, num_author_hashes * D]
        author_flat = history_author_embeddings.reshape(B, S, -1)
        
        # Step 3: Embed actions (signed multi-hot trick)
        # 2 * actions - 1 maps {0,1} to {-1, +1}
        actions_signed = (2 * history_actions - 1).float()  # [B, S, num_actions]
        action_emb = self.action_proj(actions_signed)  # [B, S, D]
        
        # Mask out invalid positions (no actions)
        valid_mask = (history_actions.sum(dim=-1) > 0).unsqueeze(-1)  # [B, S, 1]
        action_emb = action_emb * valid_mask
        
        # Step 4: Product surface embeddings
        surface_emb = self.product_surface_emb(history_product_surface)  # [B, S, D]
        
        # Step 5: Concatenate all ingredients
        combined = torch.cat([post_flat, author_flat, action_emb, surface_emb], dim=-1)
        # [B, S, num_item_hashes*D + num_author_hashes*D + D + D]
        
        # Step 6: Project to single embedding
        history_emb = self.projection(combined)  # [B, S, D]
        
        # Step 7: Create padding mask (based on first post hash)
        history_padding_mask = (history_post_hashes[:, :, 0] != 0)  # [B, S]
        
        return history_emb, history_padding_mask

# Test the module
print("Testing BlockHistoryReduce...")
block_history_reduce = BlockHistoryReduce(
    NUM_ITEM_HASHES, NUM_AUTHOR_HASHES, EMB_SIZE, num_actions=19
).to(device)

# Create test history data
B, S = 4, 10  # Batch size 4, sequence length 10
history_post_hashes = torch.randint(1, VOCAB_SIZE, (B, S, NUM_ITEM_HASHES)).to(device)  # Start at 1: index 0 is the padding value
history_post_hashes[:, S//2:, :] = 0  # Zero out second half (padding)

history_post_embeddings = torch.randn(B, S, NUM_ITEM_HASHES, EMB_SIZE).to(device)
history_author_embeddings = torch.randn(B, S, NUM_AUTHOR_HASHES, EMB_SIZE).to(device)
history_product_surface = torch.randint(0, 16, (B, S)).to(device)
history_actions = torch.zeros(B, S, 19).to(device)
history_actions[:, :S//2, :] = torch.randint(0, 2, (B, S//2, 19)).to(device)  # First half valid

# Apply block history reduce
history_emb, history_mask = block_history_reduce(
    history_post_hashes,
    history_post_embeddings,
    history_author_embeddings,
    history_product_surface,
    history_actions
)

print(f"History sequence length: {S}")
print(f"Output history embeddings shape: {history_emb.shape}")  # [4, 10, 64]
print(f"Output history padding mask shape: {history_mask.shape}")  # [4, 10]
print(f"Valid positions per sample: {history_mask.sum(dim=1).tolist()}")
print(f"✓ BlockHistoryReduce works!\n")
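
To see what the signed multi-hot step does, consider a worked example with three hypothetical actions and hand-picked projection weights (both are made up for illustration). With a bias-free projection, a raw {0,1} vector only adds the weight columns of actions that happened; the {-1,+1} form also subtracts the columns of actions that did not, so "did not reply" contributes signal too.

```python
import torch

# Three hypothetical actions (say like, reply, repost) and a 2-dim embedding.
weight = torch.tensor([[1.0, 0.0, 2.0],
                       [0.0, 1.0, 2.0]])   # [D=2, num_actions=3]

actions = torch.tensor([[1.0, 0.0, 0.0],   # liked only
                        [0.0, 0.0, 0.0]])  # padding row: no actions at all

signed = 2 * actions - 1                   # {0,1} -> {-1,+1}
action_emb = signed @ weight.T             # bias-free projection, [2, D]

# Rows with no actions are zeroed, as in BlockHistoryReduce.
valid = (actions.sum(dim=-1) > 0).unsqueeze(-1)
action_emb = action_emb * valid

# First row: [1 - 0 - 2, 0 - 1 - 2] = [-1, -3]; the padding row is zeroed.
print(signed)
print(action_emb)
```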

## 5. RecsysBatch and RecsysEmbeddings Data Containers

Data structures that hold the input features and embeddings.

In [None]:
from dataclasses import dataclass

@dataclass
class RecsysBatch:
    """
    Input batch for the recommendation model.
    
    Contains the feature data (hashes, actions, product surfaces) but NOT the embeddings.
    Embeddings are passed separately via RecsysEmbeddings.
    """
    user_hashes: torch.Tensor              # [B, num_user_hashes]
    history_post_hashes: torch.Tensor      # [B, S, num_item_hashes]
    history_author_hashes: torch.Tensor    # [B, S, num_author_hashes]
    history_actions: torch.Tensor          # [B, S, num_actions]
    history_product_surface: torch.Tensor  # [B, S]
    candidate_post_hashes: torch.Tensor    # [B, C, num_item_hashes]
    candidate_author_hashes: torch.Tensor  # [B, C, num_author_hashes]
    candidate_product_surface: torch.Tensor # [B, C]

@dataclass
class RecsysEmbeddings:
    """
    Container for pre-looked-up embeddings from the embedding tables.
    
    These embeddings are looked up from hash tables before being passed to the model.
    The Block*Reduce modules then combine the multiple hash embeddings into single representations.
    """
    user_embeddings: torch.Tensor          # [B, num_user_hashes, D]
    history_post_embeddings: torch.Tensor  # [B, S, num_item_hashes, D]
    history_author_embeddings: torch.Tensor # [B, S, num_author_hashes, D]
    candidate_post_embeddings: torch.Tensor # [B, C, num_item_hashes, D]
    candidate_author_embeddings: torch.Tensor # [B, C, num_author_hashes, D]

print("Data containers defined.")
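
The containers carry candidate hashes and embeddings, but this notebook defines no reducer for them. Below is a minimal sketch of what one could look like, assuming it mirrors `BlockHistoryReduce` without the action ingredient (candidates have no actions yet). The class name `BlockCandidateReduce`, its signature, and the choice of ingredients are assumptions for illustration and are not taken from recsys_model.py.

```python
import torch
import torch.nn as nn
from typing import Tuple


class BlockCandidateReduce(nn.Module):
    """Hypothetical reducer: post + author + product surface -> one token per candidate."""

    def __init__(self, num_item_hashes: int, num_author_hashes: int,
                 emb_size: int, num_surfaces: int = 16):
        super().__init__()
        input_dim = (num_item_hashes + num_author_hashes + 1) * emb_size
        self.product_surface_emb = nn.Embedding(num_surfaces, emb_size)
        self.projection = nn.Linear(input_dim, emb_size, bias=False)

    def forward(self,
                candidate_post_hashes: torch.Tensor,        # [B, C, num_item_hashes]
                candidate_post_embeddings: torch.Tensor,    # [B, C, num_item_hashes, D]
                candidate_author_embeddings: torch.Tensor,  # [B, C, num_author_hashes, D]
                candidate_product_surface: torch.Tensor     # [B, C]
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, C = candidate_post_embeddings.shape[:2]
        post_flat = candidate_post_embeddings.reshape(B, C, -1)
        author_flat = candidate_author_embeddings.reshape(B, C, -1)
        surface_emb = self.product_surface_emb(candidate_product_surface)
        combined = torch.cat([post_flat, author_flat, surface_emb], dim=-1)
        candidate_emb = self.projection(combined)                     # [B, C, D]
        candidate_padding_mask = candidate_post_hashes[:, :, 0] != 0  # [B, C], True = valid
        return candidate_emb, candidate_padding_mask


# Usage with random stand-in tensors; the last candidate of sample 0 is padding.
torch.manual_seed(0)
B, C, H, D = 2, 3, 2, 64
hashes = torch.randint(1, 1000, (B, C, H))
hashes[0, -1] = 0
reducer = BlockCandidateReduce(H, H, D)
cand_emb, cand_mask = reducer(
    hashes, torch.randn(B, C, H, D), torch.randn(B, C, H, D), torch.randint(0, 16, (B, C))
)
print(cand_emb.shape, cand_mask.tolist())
```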

## 6. Demo: End-to-End Hash Embedding Pipeline

Create synthetic data and run through the hash embedding system.

In [None]:
print("="*60)
print("DEMO: Hash Embedding Pipeline")
print("="*60)

# Configuration
batch_size = 2
history_len = 5
num_candidates = 3

# Create hash embedding tables
user_hash_table = MultiHashEmbedding(VOCAB_SIZE, EMB_SIZE, NUM_USER_HASHES).to(device)
post_hash_table = MultiHashEmbedding(VOCAB_SIZE, EMB_SIZE, NUM_ITEM_HASHES).to(device)
author_hash_table = MultiHashEmbedding(VOCAB_SIZE, EMB_SIZE, NUM_AUTHOR_HASHES).to(device)

# Create block reducers
block_user = BlockUserReduce(NUM_USER_HASHES, EMB_SIZE).to(device)
block_history = BlockHistoryReduce(NUM_ITEM_HASHES, NUM_AUTHOR_HASHES, EMB_SIZE).to(device)

# Generate synthetic hash indices
user_hashes = torch.randint(1, VOCAB_SIZE, (batch_size, NUM_USER_HASHES)).to(device)
history_post_hashes = torch.randint(1, VOCAB_SIZE, (batch_size, history_len, NUM_ITEM_HASHES)).to(device)
history_author_hashes = torch.randint(1, VOCAB_SIZE, (batch_size, history_len, NUM_AUTHOR_HASHES)).to(device)
candidate_post_hashes = torch.randint(1, VOCAB_SIZE, (batch_size, num_candidates, NUM_ITEM_HASHES)).to(device)
candidate_author_hashes = torch.randint(1, VOCAB_SIZE, (batch_size, num_candidates, NUM_AUTHOR_HASHES)).to(device)

# Other features
history_product_surface = torch.randint(0, 16, (batch_size, history_len)).to(device)
history_actions = torch.randint(0, 2, (batch_size, history_len, 19)).float().to(device)
candidate_product_surface = torch.randint(0, 16, (batch_size, num_candidates)).to(device)

print(f"\n1. Input shapes:")
print(f"   User hashes: {user_hashes.shape}")
print(f"   History post hashes: {history_post_hashes.shape}")
print(f"   History author hashes: {history_author_hashes.shape}")
print(f"   Candidate post hashes: {candidate_post_hashes.shape}")
print(f"   Candidate author hashes: {candidate_author_hashes.shape}")

# Step 1: Look up embeddings from hash tables
print(f"\n2. Looking up embeddings from hash tables...")
user_embeddings = user_hash_table(user_hashes)  # [B, num_user_hashes, D]
history_post_embeddings = post_hash_table(
    history_post_hashes.reshape(-1, NUM_ITEM_HASHES)
).reshape(batch_size, history_len, NUM_ITEM_HASHES, EMB_SIZE)
history_author_embeddings = author_hash_table(
    history_author_hashes.reshape(-1, NUM_AUTHOR_HASHES)
).reshape(batch_size, history_len, NUM_AUTHOR_HASHES, EMB_SIZE)
candidate_post_embeddings = post_hash_table(
    candidate_post_hashes.reshape(-1, NUM_ITEM_HASHES)
).reshape(batch_size, num_candidates, NUM_ITEM_HASHES, EMB_SIZE)
candidate_author_embeddings = author_hash_table(
    candidate_author_hashes.reshape(-1, NUM_AUTHOR_HASHES)
).reshape(batch_size, num_candidates, NUM_AUTHOR_HASHES, EMB_SIZE)

print(f"   User embeddings: {user_embeddings.shape}")
print(f"   History post embeddings: {history_post_embeddings.shape}")
print(f"   History author embeddings: {history_author_embeddings.shape}")
print(f"   Candidate post embeddings: {candidate_post_embeddings.shape}")
print(f"   Candidate author embeddings: {candidate_author_embeddings.shape}")

# Step 2: Reduce embeddings
print(f"\n3. Applying block reduction...")
user_emb, user_mask = block_user(user_hashes, user_embeddings)
history_emb, history_mask = block_history(
    history_post_hashes,
    history_post_embeddings,
    history_author_embeddings,
    history_product_surface,
    history_actions
)

print(f"   User representation: {user_emb.shape}")
print(f"   User padding mask: {user_mask.shape}")
print(f"   History representation: {history_emb.shape}")
print(f"   History padding mask: {history_mask.shape}")

# Step 3: Verify shapes
assert user_emb.shape == (batch_size, 1, EMB_SIZE)
assert user_mask.shape == (batch_size, 1)
assert history_emb.shape == (batch_size, history_len, EMB_SIZE)
assert history_mask.shape == (batch_size, history_len)

print(f"\n✓ All shape assertions passed!")
print(f"\nFinal output:")
print(f"  User token: {user_emb.shape} (ready for transformer)")
print(f"  History sequence: {history_emb.shape} (ready for transformer)")
print(f"\nThese would be concatenated and fed into the transformer!\n")
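
As a sketch of that final step, the user token and history sequence can be concatenated along the sequence axis, with their masks concatenated the same way. The example uses `nn.MultiheadAttention` as a stand-in for the transformer, which is an assumption made for illustration; the actual model may apply its masks differently. Note the inversion: the masks in this notebook use True for valid positions, while `key_padding_mask` expects True for positions to ignore.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, D = 2, 5, 64

# Stand-ins for the outputs of BlockUserReduce and BlockHistoryReduce.
user_emb = torch.randn(B, 1, D)
history_emb = torch.randn(B, S, D)
user_mask = torch.ones(B, 1, dtype=torch.bool)
history_mask = torch.tensor([[True, True, True, False, False],
                             [True, True, True, True, True]])

tokens = torch.cat([user_emb, history_emb], dim=1)   # [B, 1 + S, D]
valid = torch.cat([user_mask, history_mask], dim=1)  # [B, 1 + S], True = valid

attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True)
out, weights = attn(tokens, tokens, tokens, key_padding_mask=~valid)

print(out.shape)      # one output vector per token
print(weights[0, 0])  # attention from the user token; padded keys get zero weight
```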

## 7. Visualization: Collision Probability vs Number of Hashes

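This plot shows how the probability of a full collision (two entities sharing a row in every table) falls as hash functions are added, assuming the hash functions are independent and uniform.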

In [None]:
# Compute collision probability
vocab_sizes = [1000, 10000, 100000, 1000000]
num_hashes_range = range(1, 6)

plt.figure(figsize=(10, 6))

for vocab_size in vocab_sizes:
    collision_probs = []
    for n in num_hashes_range:
        # P(two entities share a row in all n tables) = (1/vocab_size)^n,
        # assuming independent, uniform hash functions
        prob = (1.0 / vocab_size) ** n
        collision_probs.append(prob)
    
    # Plot (on log scale)
    plt.plot(list(num_hashes_range), collision_probs, 
             marker='o', label=f'Vocab={vocab_size:,}')

plt.yscale('log')
plt.xlabel('Number of Hash Functions', fontsize=12)
plt.ylabel('Collision Probability (log scale)', fontsize=12)
plt.title('Hash Collision Probability vs Number of Hash Functions\n(More hashes = exponentially fewer collisions)', fontsize=14)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

print("Key insight: each extra hash function cuts the full-collision probability by another factor of vocab_size!")
print("Example: For vocab_size=100,000:")
print(f"  1 hash: P(collision) = {1/100000:.6f}")
print(f"  2 hashes: P(collision) = {(1/100000)**2:.12f}")
print(f"  3 hashes: P(collision) = {(1/100000)**3:.18f}")
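
The formula can be checked empirically on a small scale. The sketch below models each hash function as an independent uniform draw per entity (an idealization, and the function name `full_collision_rate` is introduced here for illustration), then measures the fraction of entity pairs that share a row in every table.

```python
import random
from collections import Counter


def full_collision_rate(num_entities: int, table_size: int,
                        num_hashes: int, seed: int = 0) -> float:
    """Fraction of entity pairs whose indices match in all tables."""
    rng = random.Random(seed)
    # One tuple of table indices per entity.
    codes = [
        tuple(rng.randrange(table_size) for _ in range(num_hashes))
        for _ in range(num_entities)
    ]
    # c entities with the same tuple contribute c * (c - 1) / 2 colliding pairs.
    colliding_pairs = sum(c * (c - 1) // 2 for c in Counter(codes).values())
    total_pairs = num_entities * (num_entities - 1) // 2
    return colliding_pairs / total_pairs


for k in (1, 2):
    rate = full_collision_rate(num_entities=1000, table_size=50, num_hashes=k)
    print(f"{k} hash(es): empirical {rate:.5f} vs (1/50)^{k} = {(1 / 50) ** k:.5f}")
```

Collisions within any single table remain as likely as before; the multi-hash design only makes it unlikely that two entities collide in all tables at once, which is what lets the downstream projection tell them apart.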

## Summary

You've implemented the hash embedding system from X's recommendation algorithm!

**Key takeaways:**

1. **Multi-hash design**: Use multiple hash functions so that two entities that collide in one table are very unlikely to collide in all of them

2. **Block reduction**: Flatten + project to combine multiple embeddings into one:
   - `BlockUserReduce`: User hashes → single user token
   - `BlockHistoryReduce`: 4 ingredients → single history token per position

3. **Padding masks**: Track which positions are valid (hash != 0) for transformer attention

4. **Scalability**: Hash embeddings handle billion-entity vocabularies without one row per entity, because table size is set by the hash range rather than the entity count
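
The scalability point can be made concrete with back-of-the-envelope arithmetic. The sketch assumes float32 parameters and a hypothetical one billion entities, and reuses this notebook's `EMB_SIZE = 64`, two hash tables, and 100,000 rows per table.

```python
BYTES_PER_FLOAT32 = 4
EMB_SIZE = 64


def table_bytes(rows: int, num_tables: int = 1, emb_size: int = EMB_SIZE) -> int:
    """Memory for num_tables embedding tables of shape [rows, emb_size] in float32."""
    return num_tables * rows * emb_size * BYTES_PER_FLOAT32


full = table_bytes(rows=1_000_000_000)            # one row per entity
hashed = table_bytes(rows=100_000, num_tables=2)  # two hashed tables

print(f"One row per entity: {full / 1e9:.1f} GB")
print(f"Hashed, 2 tables:   {hashed / 1e6:.1f} MB")
```

The hashed tables are smaller by a factor of 5,000 in this configuration, at the cost of entities sharing rows.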

**Next up**: [Lecture 3 - Two-Tower Retrieval](03-two-tower-retrieval.ipynb)