# Lecture 3: Two-Tower Retrieval - Interactive Notebook

> **Companion to**: 03-two-tower-retrieval.md
> **Run time**: ~5 minutes

This notebook implements the two-tower retrieval model from X's recommendation algorithm.

**Key concepts covered:**
- Two-tower design (user tower + candidate tower)
- Signed action embeddings (2x-1 trick)
- Retrieval via matmul + top-k
- Asymmetric design rationale

---

## 1. Imports + Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

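# Configuration
EMB_SIZE = 64        # embedding dimension D shared by both towers
HIDDEN_SIZE = 128    # candidate-tower expand width (2 * EMB_SIZE)
NUM_ACTIONS = 19     # length of each multi-hot action vector
NUM_LAYERS = 2       # transformer layers in the user tower
NUM_HEADS = 4        # attention heads (head dim = EMB_SIZE // NUM_HEADS = 16)
KEY_SIZE = 16        # per-head key size; not used by the PyTorch layers below
VOCAB_SIZE = 10000   # vocabulary size; not used in this simplified notebook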

## 2. ActionEmbedding Module

Converts multi-hot action vectors to embeddings using the 2x-1 signed trick.

In [None]:
class ActionEmbedding(nn.Module):
    """
    Convert multi-hot action vectors to embeddings.
    
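    Uses the '2x-1' trick: maps {0, 1} -> {-1, +1}.
    With this bias-free projection, an action that was not performed
    contributes -w_j (minus its weight column) instead of a zero vector,
    so "not performed" carries signal just as "performed" does.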
    
    From recsys_retrieval_model.py:161-184
    """
    
    def __init__(self, num_actions: int, emb_size: int):
        super().__init__()
        self.num_actions = num_actions
        self.emb_size = emb_size
        
        # Action projection matrix: [num_actions, emb_size]
        self.proj = nn.Linear(num_actions, emb_size, bias=False)
    
    def forward(self, actions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            actions: [B, S, num_actions] multi-hot vector (0 or 1)
        
        Returns:
            action_emb: [B, S, emb_size] action embeddings
        """
        B, S, _ = actions.shape
        
        # Apply 2x-1 trick: {0, 1} -> {-1, +1}
        actions_signed = (2 * actions - 1).float()  # [B, S, num_actions]
        
        # Project to embedding space
        action_emb = self.proj(actions_signed)  # [B, S, emb_size]
        
        # Mask out invalid positions (all zeros = no actions)
        valid_mask = (actions.sum(dim=-1, keepdim=True) > 0)  # [B, S, 1]
        action_emb = action_emb * valid_mask
        
        return action_emb

# Test
print("Testing ActionEmbedding...")
action_emb_module = ActionEmbedding(NUM_ACTIONS, EMB_SIZE).to(device)

# Create dummy action vectors
actions = torch.randint(0, 2, (4, 10, NUM_ACTIONS)).float().to(device)
print(f"Actions shape: {actions.shape}")  # [4, 10, 19]

# Get embeddings
emb = action_emb_module(actions)
print(f"Action embeddings shape: {emb.shape}")  # [4, 10, 64]

# Verify the 2x-1 trick
print(f"Sample actions[0,0]: {actions[0,0,:5].tolist()}")
print(f"After 2x-1: {(2*actions[0,0,:5]-1).tolist()}")
print(f"✓ ActionEmbedding works!\n")
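A quick way to see what the signed encoding changes: for a bias-free projection matrix W, proj(2a - 1) = 2·Wa - W·1. The sketch below checks this on a tiny random W (the sizes and the names `W`, `a` are illustrative, not taken from X's code). With {0, 1} inputs only the performed actions' columns contribute; with {-1, +1} inputs every column contributes, and unperformed actions enter with a minus sign.

```python
import torch

torch.manual_seed(0)
num_actions, emb_size = 5, 3

# Bias-free projection, like ActionEmbedding.proj (illustrative sizes)
W = torch.randn(emb_size, num_actions)

a = torch.tensor([1., 0., 0., 1., 0.])   # multi-hot: actions 0 and 3 performed
signed = 2 * a - 1                        # {0, 1} -> {-1, +1}

emb_unsigned = W @ a
emb_signed = W @ signed

# Signed projection = 2 * unsigned projection - (sum of all weight columns)
offset = W.sum(dim=1)
assert torch.allclose(emb_signed, 2 * emb_unsigned - offset, atol=1e-5)

# {0, 1}: only the performed columns (0 and 3) contribute
assert torch.allclose(emb_unsigned, W[:, 0] + W[:, 3], atol=1e-5)
# {-1, +1}: every column contributes, unperformed ones with a minus sign
assert torch.allclose(emb_signed, W[:, 0] - W[:, 1] - W[:, 2] + W[:, 3] - W[:, 4], atol=1e-5)

print(signed.tolist())  # [1.0, -1.0, -1.0, 1.0, -1.0]
```

Since -W·1 is a constant offset, a projection with a learned bias on {0, 1} inputs could represent the same function; the 2x-1 form gets that offset without a bias term. `ActionEmbedding` also zeroes all-zero (padding) rows after projecting, so padding does not inherit the offset.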

## 3. CandidateTower Module

Projects post+author embeddings to a shared embedding space using an expand-compress MLP.

In [None]:
class CandidateTower(nn.Module):
    """
    Candidate tower that projects post+author embeddings to a shared embedding space.
    
    Architecture: Expand -> Compress (same as X's implementation)
    - Flatten: [B, num_hashes, D] -> [B, num_hashes * D]
    - Expand: [B, num_hashes * D] -> [B, emb_size * 2]  
    - Compress: [B, emb_size * 2] -> [B, emb_size]
    - L2 normalize
    
    From recsys_retrieval_model.py:47-99
    """
    
    def __init__(self, emb_size: int, hidden_size: int = None, num_hashes: int = 4):
        super().__init__()
        self.emb_size = emb_size
        self.hidden_size = hidden_size or emb_size * 2
        # Hash embeddings concatenated per candidate (post + author);
        # 4 matches the tests in this notebook and is an illustrative default
        self.num_hashes = num_hashes
        
        # Expand projection: [num_hashes * D] -> [2 * D]
        self.expand = nn.Linear(num_hashes * emb_size, self.hidden_size, bias=False)
        
        # Compress projection: [2 * D] -> [D]
        self.compress = nn.Linear(self.hidden_size, emb_size, bias=False)
    
    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Args:
            embeddings: [..., num_hashes, D], e.g. [B, num_hashes, D]
                        or [B, C, num_hashes, D]
        
        Returns:
            candidate_emb: [..., D] L2-normalized candidate embedding
        """
        # Flatten the hash axis: [..., num_hashes, D] -> [..., num_hashes * D]
        x = embeddings.flatten(start_dim=-2)
        
        # Expand: [..., num_hashes * D] -> [..., 2 * D]
        x = F.silu(self.expand(x))
        
        # Compress: [..., 2 * D] -> [..., D]
        x = self.compress(x)
        
        # L2 normalize so dot products with the user tower equal cosine similarity
        x = F.normalize(x, p=2, dim=-1)
        
        return x

# Test
print("Testing CandidateTower...")
candidate_tower = CandidateTower(EMB_SIZE, HIDDEN_SIZE).to(device)

# Create dummy embeddings (post + author concatenated)
num_hashes = 4  # 2 post hashes + 2 author hashes
post_author_emb = torch.randn(4, num_hashes, EMB_SIZE).to(device)
print(f"Input shape: {post_author_emb.shape}")  # [4, 4, 64]

# Get candidate embeddings
cand_emb = candidate_tower(post_author_emb)
print(f"Candidate embedding shape: {cand_emb.shape}")  # [4, 64]

# Verify L2 normalization
norms = torch.norm(cand_emb, dim=-1)
print(f"L2 norms (should be ~1.0): {norms.tolist()}")
print(f"✓ CandidateTower works!\n")
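The tower's input width depends on how many hash embeddings are concatenated per candidate. The sketch below shows the shape bookkeeping with plain tensors: `flatten(start_dim=-2)` merges the last two axes whether the input is a batch `[B, num_hashes, D]` or per-user candidate sets `[B, C, num_hashes, D]`, so the expand layer has to accept `num_hashes * D` features rather than `D`. The sizes mirror the test cell; `num_hashes = 4` is this notebook's illustrative choice, not a confirmed production value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

emb_size, num_hashes = 64, 4   # 2 post + 2 author hashes, as in the test above

batch = torch.randn(8, num_hashes, emb_size)         # [B, num_hashes, D]
per_user = torch.randn(2, 20, num_hashes, emb_size)  # [B, C, num_hashes, D]

# Flattening only the last two axes handles both layouts
flat_batch = batch.flatten(start_dim=-2)        # [8, 256]
flat_per_user = per_user.flatten(start_dim=-2)  # [2, 20, 256]

# Expand takes num_hashes * D inputs; nn.Linear acts on the last axis only
expand = nn.Linear(num_hashes * emb_size, 2 * emb_size, bias=False)  # 256 -> 128
compress = nn.Linear(2 * emb_size, emb_size, bias=False)             # 128 -> 64

out = compress(F.silu(expand(flat_per_user)))   # [2, 20, 64]
print(flat_batch.shape, flat_per_user.shape, out.shape)
print(expand.weight.numel(), compress.weight.numel())  # 32768 8192
```

With these sizes the tower holds 256 × 128 + 128 × 64 = 40,960 weights, and almost all of them sit in the expand layer because it sees every hash embedding at once.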

## 4. UserTower Module

Encodes user features + history using a transformer + mean pooling.

In [None]:
class UserTower(nn.Module):
    """
    User tower that encodes user features + history into a single representation.
    
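    Architecture:
    1. Embed user features (X's model uses hash embeddings; this simplified
       version receives a precomputed user embedding)
    2. Embed history (X's model combines 4 ingredients per position; here the
       history embeddings also arrive precomputed)
    3. Concatenate: [user] + [history]
    4. Transformer encoder
    5. Mean pool all valid (non-padding) positions
    6. L2 normalize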
    
    From recsys_retrieval_model.py:206-276
    """
    
    def __init__(self, 
                 emb_size: int,
                 num_actions: int,
                 num_layers: int = NUM_LAYERS,
                 num_heads: int = NUM_HEADS):
        super().__init__()
        self.emb_size = emb_size
        self.num_actions = num_actions
        
        # User embedding: simplified stand-in for the hash-based lookup
        self.user_emb = nn.Linear(emb_size, emb_size, bias=False)
        
        # Signed action embeddings for history positions
        self.action_emb = ActionEmbedding(num_actions, emb_size)
        
        # Transformer encoder. nn.TransformerEncoderLayer only accepts the
        # strings 'relu' and 'gelu', so SiLU is passed as a callable.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_size,
            nhead=num_heads,
            dim_feedforward=emb_size * 4,
            activation=F.silu,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    
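    def forward(self,
                user_embedding: torch.Tensor,
                history_embeddings: torch.Tensor,
                padding_mask: torch.Tensor = None,
                history_actions: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            user_embedding: [B, 1, D] user token
            history_embeddings: [B, S, D] history sequence
            padding_mask: [B, S] boolean mask (True = valid)
            history_actions: optional [B, S, num_actions] multi-hot actions
                             for each history position
        
        Returns:
            user_repr: [B, D] L2-normalized user representation
        """
        B, S, D = history_embeddings.shape
        
        # Project the user token (simplified stand-in for hash embeddings)
        user_token = self.user_emb(user_embedding)  # [B, 1, D]
        
        # Add signed action embeddings to each history position, if provided
        if history_actions is not None:
            history_embeddings = history_embeddings + self.action_emb(history_actions.float())
        
        # Concatenate: [user] + [history]
        # [B, 1, D] + [B, S, D] -> [B, 1+S, D]
        sequence = torch.cat([user_token, history_embeddings], dim=1)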
        
        # Create full padding mask (user token is always valid)
        if padding_mask is not None:
            user_valid = torch.ones(B, 1, dtype=torch.bool, device=padding_mask.device)
            full_mask = torch.cat([user_valid, padding_mask], dim=1)  # [B, 1+S]
        else:
            full_mask = None
        
        # Transformer
        if full_mask is not None:
            # src_key_padding_mask uses True = ignore, the inverse of this
            # notebook's True = valid convention
            mask = ~full_mask
            output = self.transformer(sequence, src_key_padding_mask=mask)
        else:
            output = self.transformer(sequence)
        
        # Mean pool over valid positions (user token + non-padded history)
        if full_mask is not None:
            # Weight each valid position equally; weights sum to 1 per row
            weights = full_mask.float()  # [B, 1+S]
            weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
            user_repr = (output * weights.unsqueeze(-1)).sum(dim=1)  # [B, D]
        else:
            user_repr = output.mean(dim=1)  # [B, D]
        
        # L2 normalize
        user_repr = F.normalize(user_repr, p=2, dim=-1)
        
        return user_repr

# Test
print("Testing UserTower...")
user_tower = UserTower(EMB_SIZE, NUM_ACTIONS, NUM_LAYERS, NUM_HEADS).to(device)

# Create dummy inputs
B, S = 4, 10
user_embedding = torch.randn(B, 1, EMB_SIZE).to(device)
history_embeddings = torch.randn(B, S, EMB_SIZE).to(device)
padding_mask = torch.ones(B, S, dtype=torch.bool).to(device)
padding_mask[:, S//2:] = False  # Second half is padding

print(f"User embedding: {user_embedding.shape}")
print(f"History embeddings: {history_embeddings.shape}")
print(f"Padding mask: {padding_mask.shape}")

# Forward pass
user_repr = user_tower(user_embedding, history_embeddings, padding_mask)
print(f"User representation: {user_repr.shape}")

# Verify L2 normalization
norms = torch.norm(user_repr, dim=-1)
print(f"L2 norms (should be ~1.0): {norms.tolist()}")
print(f"✓ UserTower works!\n")
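The masking in `UserTower.forward` juggles two conventions: the notebook's `padding_mask` uses True = valid, while `nn.TransformerEncoder`'s `src_key_padding_mask` uses True = ignore. The sketch below isolates the pooling step on a random stand-in for the transformer output (sizes are arbitrary) and checks that the weighted sum equals a plain mean over the valid rows only.

```python
import torch

torch.manual_seed(0)
B, L, D = 2, 5, 3
output = torch.randn(B, L, D)   # stand-in for transformer outputs [B, 1+S, D]

# True = valid; position 0 is the user token, row 0 has two padded positions
full_mask = torch.tensor([[True, True, True, False, False],
                          [True, True, True, True, True]])

# What nn.TransformerEncoder expects: True = ignore
src_key_padding_mask = ~full_mask

# Weighted mean over valid positions, as in UserTower.forward
weights = full_mask.float()
weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
pooled = (output * weights.unsqueeze(-1)).sum(dim=1)   # [B, D]

# Same as averaging the valid rows directly
assert torch.allclose(pooled[0], output[0, :3].mean(dim=0), atol=1e-6)
assert torch.allclose(pooled[1], output[1].mean(dim=0), atol=1e-6)
print(src_key_padding_mask[0].tolist())  # [False, False, False, True, True]
```

Padded positions get weight 0, so whatever the encoder writes there cannot leak into the pooled user vector.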

## 5. TwoTowerRetrieval Module

Combines both towers: each candidate is scored by the dot product of the user embedding and the candidate embedding, and the top-k highest-scoring candidates are retrieved.

In [None]:
class TwoTowerRetrieval(nn.Module):
    """
    Two-tower retrieval model for candidate generation.
    
    Architecture:
    - User Tower: Encodes user + history -> user_repr [B, D]
    - Candidate Tower: Encodes candidates -> cand_repr [B, C, D]
    - Retrieval: user_repr @ cand_repr.T -> similarity [B, C]
    
    Both towers output L2-normalized embeddings, so dot product = cosine similarity.
    """
    
    def __init__(self,
                 emb_size: int,
                 hidden_size: int,
                 num_actions: int,
                 num_layers: int = NUM_LAYERS,
                 num_heads: int = NUM_HEADS):
        super().__init__()
        self.emb_size = emb_size
        
        self.candidate_tower = CandidateTower(emb_size, hidden_size)
        self.user_tower = UserTower(emb_size, num_actions, num_layers, num_heads)
    
    def forward(self,
                user_embedding: torch.Tensor,
                history_embeddings: torch.Tensor,
                padding_mask: torch.Tensor,
                candidate_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Args:
            user_embedding: [B, 1, D] user token
            history_embeddings: [B, S, D] history sequence
            padding_mask: [B, S] boolean mask (True = valid)
            candidate_embeddings: [B, C, num_hashes, D] candidate embeddings
        
        Returns:
            scores: [B, C] similarity scores (before softmax)
        """
        # User tower -> user_repr [B, D]
        user_repr = self.user_tower(user_embedding, history_embeddings, padding_mask)
        
        # Candidate tower -> cand_repr [B, C, D]
        cand_repr = self.candidate_tower(candidate_embeddings)
        
        # Dot product similarity (both are L2-normalized)
        # [B, 1, D] @ [B, D, C] -> [B, 1, C] -> squeeze -> [B, C]
        scores = torch.bmm(user_repr.unsqueeze(1), cand_repr.transpose(1, 2)).squeeze(1)
        # Equivalent to: (user_repr.unsqueeze(1) * cand_repr).sum(dim=-1)
        
        return scores
    
    def retrieve_top_k(self,
                       user_repr: torch.Tensor,
                       corpus_embeddings: torch.Tensor,
                       k: int) -> tuple:
        """
        Retrieve top-k candidates from a corpus.
        
        Args:
            user_repr: [B, D] user representation
            corpus_embeddings: [N, D] pre-computed candidate embeddings
            k: number of candidates to retrieve
        
        Returns:
            top_k_indices: [B, k] indices of top-k candidates
            top_k_scores: [B, k] similarity scores
        """
        # Compute similarity: [B, D] @ [N, D].T -> [B, N]
        scores = torch.mm(user_repr, corpus_embeddings.T)
        
        # Top-k
        top_k_scores, top_k_indices = torch.topk(scores, k, dim=-1)
        
        return top_k_indices, top_k_scores

# Test
print("Testing TwoTowerRetrieval...")
model = TwoTowerRetrieval(EMB_SIZE, HIDDEN_SIZE, NUM_ACTIONS, NUM_LAYERS, NUM_HEADS).to(device)
model.eval()  # disable dropout so the repeated tower calls below are deterministic

# Create dummy data
B, S, C = 4, 10, 20
num_hashes = 4

user_embedding = torch.randn(B, 1, EMB_SIZE).to(device)
history_embeddings = torch.randn(B, S, EMB_SIZE).to(device)
padding_mask = torch.ones(B, S, dtype=torch.bool).to(device)
candidate_embeddings = torch.randn(B, C, num_hashes, EMB_SIZE).to(device)

# Forward pass
scores = model(user_embedding, history_embeddings, padding_mask, candidate_embeddings)
print(f"Scores shape: {scores.shape}")  # [B, C]
print(f"Score range: [{scores.min():.3f}, {scores.max():.3f}]")

# Verify L2 normalization
user_repr = model.user_tower(user_embedding, history_embeddings, padding_mask)
cand_repr = model.candidate_tower(candidate_embeddings)
user_norms = torch.norm(user_repr, dim=-1)
cand_norms = torch.norm(cand_repr, dim=-1)
print(f"User repr norms: {user_norms.mean():.4f}")
print(f"Candidate repr norms: {cand_norms.mean():.4f}")
print(f"✓ TwoTowerRetrieval works!\n")
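The per-user scoring in `forward()` is a batched dot product. The sketch below, on random unit vectors with the test cell's shapes, checks that three spellings of it agree. The broadcast version needs `user_repr.unsqueeze(1)`: multiplying `[B, D]` by `[B, C, D]` directly would try to align `B` with `C`, which fails here and silently gives wrong scores whenever `B == C`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, D = 4, 20, 64
user_repr = F.normalize(torch.randn(B, D), dim=-1)     # [B, D]
cand_repr = F.normalize(torch.randn(B, C, D), dim=-1)  # [B, C, D]

# 1. Batched matmul, as in TwoTowerRetrieval.forward
s_bmm = torch.bmm(user_repr.unsqueeze(1), cand_repr.transpose(1, 2)).squeeze(1)  # [B, C]
# 2. Broadcast multiply + sum, with an explicit candidate axis on the user side
s_bcast = (user_repr.unsqueeze(1) * cand_repr).sum(dim=-1)                       # [B, C]
# 3. einsum spelling of the same contraction
s_einsum = torch.einsum('bd,bcd->bc', user_repr, cand_repr)                       # [B, C]

assert torch.allclose(s_bmm, s_bcast, atol=1e-6)
assert torch.allclose(s_bmm, s_einsum, atol=1e-6)
# Unit vectors: every score is a cosine similarity in [-1, 1]
assert s_bmm.abs().max().item() <= 1 + 1e-6
print(s_bmm.shape)  # torch.Size([4, 20])
```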

## 6. Demo: Retrieval with Synthetic Corpus

In [None]:
print("="*60)
print("DEMO: Two-Tower Retrieval")
print("="*60)

# Create model
model = TwoTowerRetrieval(EMB_SIZE, HIDDEN_SIZE, NUM_ACTIONS, NUM_LAYERS, NUM_HEADS).to(device)
model.eval()

# Create synthetic users
B = 3  # 3 users
user_embedding = torch.randn(B, 1, EMB_SIZE).to(device)
history_embeddings = torch.randn(B, 10, EMB_SIZE).to(device)
padding_mask = torch.ones(B, 10, dtype=torch.bool).to(device)

# Create synthetic corpus
N = 1000  # 1000 items in corpus
corpus_embeddings = torch.randn(N, EMB_SIZE).to(device)
corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=-1)  # L2 normalize

# Get user representations (no_grad: inference only, and it lets .numpy() work in the plots below)
with torch.no_grad():
    user_repr = model.user_tower(user_embedding, history_embeddings, padding_mask)
print(f"\n1. User representations: {user_repr.shape}")
print(f"   L2 norms: {torch.norm(user_repr, dim=-1).tolist()}")

# Retrieve top-5 candidates for each user
top_k_indices, top_k_scores = model.retrieve_top_k(user_repr, corpus_embeddings, k=5)

print(f"\n2. Top-5 retrievals per user:")
for i in range(B):
    print(f"   User {i}: indices={top_k_indices[i].tolist()}")
    print(f"            scores={top_k_scores[i].tolist()}")

# Verify that norms are ~1.0
print(f"\n3. Norm verification:")
print(f"   User repr norms: {torch.norm(user_repr, dim=-1).tolist()}")
print(f"   Corpus norms: {torch.norm(corpus_embeddings, dim=-1)[:5].tolist()}")

# Visualization: retrieval scores distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i in range(B):
    all_scores = torch.mm(user_repr[i:i+1], corpus_embeddings.T).squeeze()
    axes[i].hist(all_scores.cpu().numpy(), bins=50, alpha=0.7, edgecolor='black')
    axes[i].axvline(top_k_scores[i, 0].item(), color='red', linestyle='--', 
                    label=f'Top-1: {top_k_scores[i, 0]:.3f}')
    axes[i].set_title(f'User {i}')
    axes[i].set_xlabel('Similarity Score')
    axes[i].set_ylabel('Count')
    axes[i].legend()
plt.suptitle('Retrieval Score Distributions (corpus of 1000 items)')
plt.tight_layout()
plt.show()

print(f"\n✓ Demo complete! Each user retrieved its top-5 highest-scoring items (the model is untrained, so scores do not yet reflect relevance).\n")
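`retrieve_top_k` materializes the full `[B, N]` score matrix, which is fine for the 1,000-item corpus here. One way to keep exact top-k while holding only a slice of scores in memory is to score the corpus in chunks and merge each chunk into a running top-k. The sketch below assumes that approach; `chunked_top_k` and `chunk_size` are hypothetical names, not part of X's code, and at much larger scales the summary's ANN route is the usual replacement for exact search.

```python
import torch
import torch.nn.functional as F

def chunked_top_k(user_repr: torch.Tensor,
                  corpus: torch.Tensor,
                  k: int,
                  chunk_size: int = 256) -> tuple:
    """Hypothetical helper: exact top-k, scoring the corpus chunk by chunk.

    Only [B, k + chunk_size] scores are held at once instead of [B, N].
    Returns (indices, scores), matching retrieve_top_k's order.
    """
    B = user_repr.shape[0]
    device = user_repr.device
    best_scores = torch.full((B, 0), float('-inf'), device=device)
    best_idx = torch.empty((B, 0), dtype=torch.long, device=device)
    for start in range(0, corpus.shape[0], chunk_size):
        chunk = corpus[start:start + chunk_size]                # [n, D]
        scores = user_repr @ chunk.T                            # [B, n]
        idx = torch.arange(start, start + chunk.shape[0], device=device).expand(B, -1)
        # Merge this chunk with the running top-k, then keep the best k
        all_scores = torch.cat([best_scores, scores], dim=1)
        all_idx = torch.cat([best_idx, idx], dim=1)
        k_eff = min(k, all_scores.shape[1])
        best_scores, pos = torch.topk(all_scores, k_eff, dim=-1)
        best_idx = torch.gather(all_idx, 1, pos)
    return best_idx, best_scores

# Usage: compare against a single full matmul + topk
torch.manual_seed(0)
users = F.normalize(torch.randn(3, 64), dim=-1)
corpus = F.normalize(torch.randn(1000, 64), dim=-1)
idx, scores = chunked_top_k(users, corpus, k=5, chunk_size=128)
ref_scores, ref_idx = torch.topk(users @ corpus.T, 5, dim=-1)
print(idx.shape, scores.shape)
```

The merge is exact because any item in the global top-k must also be in the top-k of the items seen so far when its chunk is processed.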

## Summary

You've implemented the two-tower retrieval model from X's recommendation algorithm!

**Key takeaways:**

1. **Two-tower design**: User tower and candidate tower are completely separate
   - Allows pre-computing candidate embeddings offline
   - User tower runs online for each request

2. **Asymmetric architecture**:
   - User tower: Transformer (complex, learns user preferences)
   - Candidate tower: Simple MLP (just projects embeddings)

3. **Signed action embeddings**: The 2x-1 trick makes each unperformed action contribute an explicit negative signal instead of a zero vector

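4. **Efficient retrieval**: Because the score is a plain dot product, the exact matmul + top-k used here can be swapped for approximate nearest-neighbor (ANN) search over precomputed candidate embeddings, which is what makes billion-scale retrieval practical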

5. **L2 normalization**: Both towers output unit vectors, so dot product = cosine similarity
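The unit-norm property is also what links dot-product retrieval to nearest-neighbor search: for unit vectors u and c, ||u - c||² = ||u||² + ||c||² - 2 u·c = 2 - 2 u·c. Ranking by largest dot product, largest cosine similarity, or smallest Euclidean distance therefore gives the same order. A quick numerical check on random unit vectors (a sketch, not part of the model):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
u = F.normalize(torch.randn(64), dim=-1)           # user vector
cands = F.normalize(torch.randn(500, 64), dim=-1)  # candidate vectors

dot = cands @ u                                                   # [500]
cos = F.cosine_similarity(cands, u.unsqueeze(0).expand_as(cands), dim=-1)
sq_dist = ((cands - u) ** 2).sum(dim=-1)                          # [500]

# ||u - c||^2 = 2 - 2 u.c for unit vectors; dot product = cosine similarity
assert torch.allclose(sq_dist, 2 - 2 * dot, atol=1e-5)
assert torch.allclose(dot, cos, atol=1e-5)

# Same top-10 whether we maximize dot product or minimize distance
top_dot = set(torch.topk(dot, 10).indices.tolist())
top_near = set(torch.topk(-sq_dist, 10).indices.tolist())
assert top_dot == top_near
```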

**Next up**: [Lecture 4 - Ranking Model](04-ranking-model.ipynb)