# t-SNE from Scratch: Understanding the Core Algorithm

**t-Distributed Stochastic Neighbor Embedding (t-SNE)** is a dimensionality reduction technique that's particularly good at preserving local structure in high-dimensional data.

## Core Idea

1. **High-dimensional space**: Convert distances between points to probability distributions (using a Gaussian kernel)
2. **Low-dimensional space**: Convert distances to probability distributions (using a Student's t-distribution)
3. **Optimization**: Minimize the difference (KL divergence) between these two distributions

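## Why a Student's t-distribution in low dimensions?
- It eases the "crowding problem": a 2D/3D map has too little room to reproduce all the moderate distances in a high-dimensional neighborhood, so with a Gaussian in both spaces points get squeezed together in the middle of the map
- Its heavier tails mean that points which are moderately far apart in high dimensions can be placed much farther apart in the map at little cost, which separates clusters more clearly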
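To make the tail difference concrete, the sketch below compares the unnormalized Gaussian kernel (with an assumed bandwidth $\sigma = 1$) against the Student's t kernel with one degree of freedom at a few distances. Near zero the two are comparable, but the t kernel decays only like $1/d^2$, so it assigns far more similarity to distant pairs. Read the other way round: to represent a given small similarity, the t kernel needs a larger distance, which is what pushes moderately distant points apart in the map.

```python
import numpy as np

# Unnormalized similarity kernels as a function of the distance d between two points
d = np.array([0.0, 1.0, 2.0, 3.0, 5.0])
gaussian = np.exp(-d**2 / 2)       # Gaussian kernel with an assumed sigma = 1 (high-D choice)
student_t = 1.0 / (1.0 + d**2)     # Student's t with one degree of freedom (low-D choice)

for di, g, t in zip(d, gaussian, student_t):
    print(f"d={di:.0f}  gaussian={g:.2e}  student-t={t:.2e}  ratio={t / g:.1f}")
```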

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits, make_classification
from sklearn.preprocessing import StandardScaler
import seaborn as sns

np.random.seed(42)
sns.set_style('whitegrid')

## Step 1: Compute Pairwise Similarities in High-Dimensional Space

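For each point $x_i$, we compute the conditional probability $p_{j|i}$ that $x_i$ would pick $x_j$ as its neighbor if neighbors were picked in proportion to a Gaussian density centered at $x_i$:

$$p_{j|i} = \frac{\exp(-||x_i - x_j||^2 / 2\sigma_i^2)}{\sum_{k \neq i} \exp(-||x_i - x_k||^2 / 2\sigma_i^2)}$$

with $p_{i|i} = 0$. The bandwidth $\sigma_i$ is chosen separately for each point, by binary search, so that the distribution $P_i$ has a user-specified perplexity.

Then symmetrize over the $N$ points: $p_{ij} = \frac{p_{j|i} + p_{i|j}}{2N}$, so that $\sum_{i,j} p_{ij} = 1$.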
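As a small worked example, the following sketch computes $p_{j|i}$ and the symmetrized $p_{ij}$ for four hypothetical 2-D points, using hand-picked bandwidths $\sigma_i$ instead of the perplexity search. The fourth point is an outlier: the other points give it almost no probability, but after symmetrization every row of $P$ sums to at least $1/(2N)$, so the outlier still influences the embedding.

```python
import numpy as np

# Four hypothetical 2-D points: three close together, one far away
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
sigmas = np.array([1.0, 1.0, 1.0, 2.0])   # assumed per-point bandwidths sigma_i
N = X.shape[0]

sq_dist = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)

# Conditional probabilities p_{j|i}: row i is a Gaussian centered on x_i
P_cond = np.exp(-sq_dist / (2 * sigmas[:, None] ** 2))
np.fill_diagonal(P_cond, 0.0)                 # p_{i|i} = 0
P_cond /= P_cond.sum(axis=1, keepdims=True)   # each row sums to 1

# Symmetrize: p_ij = (p_{j|i} + p_{i|j}) / (2N)
P = (P_cond + P_cond.T) / (2 * N)

print(np.round(P, 4))
print("sum of P:", P.sum())
print("row sums:", np.round(P.sum(axis=1), 4))
```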

In [None]:
def compute_pairwise_distances(X):
    """
    Compute squared Euclidean distances between all pairs of points.
    
    Args:
        X: (n_samples, n_features) array
    
    Returns:
        distances: (n_samples, n_samples) array of squared distances
    """
    sum_X = np.sum(X**2, axis=1)
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2*x_i·x_j
    distances = sum_X[:, np.newaxis] + sum_X[np.newaxis, :] - 2 * np.dot(X, X.T)
    distances = np.maximum(distances, 0)  # numerical stability
    return distances
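The expansion in `compute_pairwise_distances` can lose precision through cancellation, which is why the result is clipped at zero. A quick sanity check (a sketch on random data, not part of the original pipeline) compares it with the direct, memory-hungry computation that forms every difference vector explicitly:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))

# Vectorized form: ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
sq_norms = np.sum(X**2, axis=1)
fast = sq_norms[:, None] + sq_norms[None, :] - 2 * X @ X.T
fast = np.maximum(fast, 0)   # clip tiny negatives caused by cancellation

# Direct form: build all differences explicitly (O(n^2 d) memory)
slow = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)

print("max abs difference:", np.max(np.abs(fast - slow)))
```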

In [None]:
def compute_perplexity_sigma(distances, target_perplexity=30.0, tol=1e-5, max_iter=50):
    """
    Binary search to find the sigma that gives the target perplexity.
    
    Perplexity is a measure of the effective number of neighbors.
    It's defined as: Perplexity(P_i) = 2^(H(P_i))
    where H(P_i) is the Shannon entropy (in bits) of the distribution P_i.
    
    Args:
        distances: (n_samples,) array of squared distances from point i to all
            other points, followed by one placeholder entry for i itself
            (the last entry is ignored)
        target_perplexity: desired perplexity value
        tol: tolerance on |perplexity - target_perplexity|
        max_iter: maximum number of binary-search steps
    
    Returns:
        sigma: the Gaussian bandwidth that achieves the target perplexity
    """
    others = distances[:-1]  # drop the placeholder for point i itself
    # Shift by the nearest-neighbor distance: the constant factor cancels after
    # normalization, and the nearest neighbor always gets weight exp(0) = 1, so
    # the sum can never underflow to 0 (e.g. for outliers far from everything).
    shifted = others - np.min(others)
    beta_min = -np.inf
    beta_max = np.inf
    beta = 1.0  # beta = 1/(2*sigma^2)
    
    for _ in range(max_iter):
        # Compute P_i with current beta
        P = np.exp(-shifted * beta)
        P = P / np.sum(P)
        
        # Compute entropy (in bits) and perplexity
        H = -np.sum(P * np.log2(P + 1e-10))
        perplexity = 2 ** H
        
        # Binary search on beta
        perplexity_diff = perplexity - target_perplexity
        if np.abs(perplexity_diff) < tol:
            break
        
        if perplexity_diff > 0:
            # Too many effective neighbors: narrow the kernel (increase beta)
            beta_min = beta
            beta = (beta + beta_max) / 2 if beta_max != np.inf else beta * 2
        else:
            # Too few effective neighbors: widen the kernel (decrease beta)
            beta_max = beta
            beta = (beta + beta_min) / 2 if beta_min != -np.inf else beta / 2
    
    return np.sqrt(1 / (2 * beta))  # return sigma
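The docstring's claim that perplexity is an "effective number of neighbors" can be checked directly: a distribution spread uniformly over $k$ neighbors has entropy $\log_2 k$ bits and therefore perplexity $k$, while a skewed distribution over the same neighbors has a smaller perplexity. The helper `perplexity` below is a hypothetical standalone function for illustration:

```python
import numpy as np

def perplexity(p):
    """2 ** H(p), with H the Shannon entropy in bits; zero entries contribute nothing."""
    p = p[p > 0]
    return 2 ** (-np.sum(p * np.log2(p)))

# A point that spreads its probability uniformly over k neighbors
for k in (5, 30):
    p = np.zeros(100)
    p[:k] = 1.0 / k
    print(k, perplexity(p))

# A skewed distribution over 30 neighbors has a lower effective count
p = np.linspace(1, 30, 30) ** 3
p /= p.sum()
print("skewed:", perplexity(p))
```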

In [None]:
def compute_p_high_dim(X, perplexity=30.0):
    """
    Compute the high-dimensional probability matrix P.
    
    Args:
        X: (n_samples, n_features) array
        perplexity: target perplexity value
    
    Returns:
        P: (n_samples, n_samples) symmetrized probability matrix
    """
    n = X.shape[0]
    distances = compute_pairwise_distances(X)
    P = np.zeros((n, n))
    
    print("Computing high-dimensional probabilities...")
    for i in range(n):
        if i % 100 == 0:
            print(f"  Processing point {i}/{n}")
        
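        # Distances from point i to all others, plus a trailing placeholder for i itself
        dist_i = np.concatenate([distances[i, :i], distances[i, i+1:], [0]])
        
        # Find sigma for this point
        sigma = compute_perplexity_sigma(dist_i, perplexity)
        
        # Compute conditional probabilities p_{j|i}. Subtracting the nearest-neighbor
        # distance cancels after normalization but keeps exp() from underflowing to 0.
        row = distances[i, :].copy()
        row[i] = np.inf  # exp(-inf) = 0 excludes j = i
        P[i, :] = np.exp(-(row - np.min(row)) / (2 * sigma**2))
        P[i, :] /= np.sum(P[i, :])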
    
    # Symmetrize and normalize
    P = (P + P.T) / (2 * n)
    P = np.maximum(P, 1e-12)  # numerical stability
    
    return P

## Step 2: Compute Similarities in Low-Dimensional Space

In the low-dimensional space (typically 2D), we use a Student's t-distribution with one degree of freedom:

$$q_{ij} = \frac{(1 + ||y_i - y_j||^2)^{-1}}{\sum_{k \neq l} (1 + ||y_k - y_l||^2)^{-1}}$$

This is a Cauchy kernel, which has much heavier tails than a Gaussian. Unlike $P$, $Q$ needs no per-point bandwidth and is normalized once over all pairs rather than row by row.
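For comparison with the per-row normalization of $P$, here is a sketch computing $Q$ for three hypothetical points in 2D with a single normalization over all pairs. The closest pair gets most of the probability mass, and the matrix is symmetric by construction:

```python
import numpy as np

# Three hypothetical points in a 2-D embedding: y0 and y1 close, y2 far away
Y = np.array([[0.0, 0.0], [0.5, 0.0], [4.0, 3.0]])

sq_dist = np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
W = 1.0 / (1.0 + sq_dist)    # Student-t (Cauchy) kernel
np.fill_diagonal(W, 0.0)     # exclude the k = l terms
Q = W / W.sum()              # one normalization over ALL pairs, not per row

print(np.round(Q, 4))
print("sum of Q:", Q.sum())
```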

In [None]:
def compute_q_low_dim(Y):
    """
    Compute the low-dimensional probability matrix Q using Student t-distribution.
    
    Args:
        Y: (n_samples, n_components) array of low-dimensional embeddings
    
    Returns:
        Q: (n_samples, n_samples) probability matrix
    """
    n = Y.shape[0]
    distances = compute_pairwise_distances(Y)
    
    # Student t-distribution with df=1
    Q = 1 / (1 + distances)
    np.fill_diagonal(Q, 0)
    
    # Normalize
    Q = Q / np.sum(Q)
    Q = np.maximum(Q, 1e-12)  # numerical stability
    
    return Q

## Step 3: Optimize with Gradient Descent

The cost function is the Kullback-Leibler divergence:

$$C = \sum_i \sum_j p_{ij} \log \frac{p_{ij}}{q_{ij}}$$

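The gradient with respect to $y_i$ (for symmetric $P$ and $Q$) is:

$$\frac{\partial C}{\partial y_i} = 4 \sum_j (p_{ij} - q_{ij})(y_i - y_j)(1 + ||y_i - y_j||^2)^{-1}$$

The embedding is updated by gradient descent with momentum: each step adds a fraction of the previous update to the current one, which speeds up progress along directions where successive gradients agree.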
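Before trusting an optimizer, it is worth checking the analytic gradient against finite differences of the KL cost. The sketch below does this on a random symmetric $P$ (zero diagonal, summing to 1), which is the setting in which the formula above holds. It also uses a vectorized form of the sum, $\sum_j M_{ij}(y_i - y_j) = y_i \sum_j M_{ij} - (MY)_i$ with $M_{ij} = (p_{ij} - q_{ij})(1 + ||y_i - y_j||^2)^{-1}$, which avoids building the $(n, n, d)$ difference tensor used in `compute_gradient`. The function names are hypothetical:

```python
import numpy as np

def q_matrix(Y):
    """Student-t (Cauchy) kernel W with zero diagonal, and Q = W normalized over all pairs."""
    d2 = np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    W = 1.0 / (1.0 + d2)
    np.fill_diagonal(W, 0.0)
    return W / W.sum(), W

def kl_cost(P, Y):
    Q, _ = q_matrix(Y)
    mask = P > 0                          # skip the zero diagonal
    return np.sum(P[mask] * np.log(P[mask] / Q[mask]))

def analytic_grad(P, Y):
    Q, W = q_matrix(Y)
    M = (P - Q) * W                       # (p_ij - q_ij)(1 + ||y_i - y_j||^2)^-1
    return 4.0 * (M.sum(axis=1)[:, None] * Y - M @ Y)

rng = np.random.default_rng(0)
n, d = 6, 2
A = rng.random((n, n))
P = A + A.T
np.fill_diagonal(P, 0.0)
P /= P.sum()                              # symmetric, zero diagonal, sums to 1
Y = rng.normal(size=(n, d))

# Central finite differences, one coordinate at a time
eps = 1e-6
num_grad = np.zeros_like(Y)
for i in range(n):
    for k in range(d):
        Yp, Ym = Y.copy(), Y.copy()
        Yp[i, k] += eps
        Ym[i, k] -= eps
        num_grad[i, k] = (kl_cost(P, Yp) - kl_cost(P, Ym)) / (2 * eps)

max_err = np.max(np.abs(num_grad - analytic_grad(P, Y)))
print(f"max abs difference: {max_err:.2e}")
```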

In [None]:
def compute_gradient(P, Q, Y):
    """
    Compute gradient of KL divergence with respect to Y.
    
    Args:
        P: (n_samples, n_samples) high-dimensional probabilities
        Q: (n_samples, n_samples) low-dimensional probabilities
        Y: (n_samples, n_components) current low-dimensional embedding
    
    Returns:
        gradient: (n_samples, n_components) gradient
    """
    n, n_components = Y.shape
    
    # Pairwise differences
    diff = Y[:, np.newaxis, :] - Y[np.newaxis, :, :]  # (n, n, n_components)
    
    # Compute distances for Student-t kernel
    distances = compute_pairwise_distances(Y)
    inv_distances = 1 / (1 + distances)
    np.fill_diagonal(inv_distances, 0)
    
    # Gradient: 4 * sum_j (p_ij - q_ij) * (y_i - y_j) * (1 + ||y_i - y_j||^2)^(-1)
    PQ_diff = P - Q
    gradient = 4 * np.sum(
        (PQ_diff[:, :, np.newaxis] * diff * inv_distances[:, :, np.newaxis]),
        axis=1
    )
    
    return gradient

In [None]:
def tsne(X, n_components=2, perplexity=30.0, n_iter=1000, learning_rate=200.0, momentum=0.8):
    """
    t-SNE implementation from scratch.
    
    Args:
        X: (n_samples, n_features) input data
        n_components: dimension of embedding (usually 2)
        perplexity: controls effective number of neighbors (5-50 typical)
        n_iter: number of optimization iterations
        learning_rate: gradient descent learning rate
        momentum: momentum coefficient
    
    Returns:
        Y: (n_samples, n_components) low-dimensional embedding
        losses: list of KL divergence values during optimization
    """
    n = X.shape[0]
    
    # Step 1: Compute high-dimensional probabilities
    P = compute_p_high_dim(X, perplexity)
    
    # Early exaggeration: multiply P by 4 for first 250 iterations
    P_exaggerated = P * 4.0
    
    # Step 2: Initialize low-dimensional embedding
    Y = np.random.randn(n, n_components) * 1e-4
    
    # For momentum
    Y_velocity = np.zeros_like(Y)
    
    losses = []
    
    print("\nOptimizing embedding...")
    for iteration in range(n_iter):
        # Use exaggerated P for first 250 iterations
        P_current = P_exaggerated if iteration < 250 else P
        
        # Step 3: Compute low-dimensional probabilities
        Q = compute_q_low_dim(Y)
        
        # Step 4: Compute gradient
        gradient = compute_gradient(P_current, Q, Y)
        
        # Step 5: Update with momentum
        Y_velocity = momentum * Y_velocity - learning_rate * gradient
        Y = Y + Y_velocity
        
        # Center the embedding
        Y = Y - np.mean(Y, axis=0)
        
        # Compute loss (KL divergence)
        loss = np.sum(P * np.log(P / Q))
        losses.append(loss)
        
        if iteration % 100 == 0:
            print(f"  Iteration {iteration}: KL divergence = {loss:.4f}")
    
    print(f"\nOptimization complete! Final loss: {losses[-1]:.4f}")
    return Y, losses

## Example 1: Digits Dataset

Let's apply t-SNE to the classic digits dataset: 8×8 images of handwritten digits, flattened to 64 features.

In [None]:
# Load and prepare data
digits = load_digits()
X_digits = digits.data
y_digits = digits.target

# Standardize
scaler = StandardScaler()
X_digits = scaler.fit_transform(X_digits)

# Use subset for faster computation
n_samples = 500
indices = np.random.choice(len(X_digits), n_samples, replace=False)
X_subset = X_digits[indices]
y_subset = y_digits[indices]

print(f"Dataset shape: {X_subset.shape}")
print(f"Number of classes: {len(np.unique(y_subset))}")

In [None]:
# Run t-SNE
Y_embedded, losses = tsne(
    X_subset, 
    n_components=2, 
    perplexity=30.0, 
    n_iter=1000,
    learning_rate=200.0
)

In [None]:
# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Plot 1: Loss curve
axes[0].plot(losses, linewidth=2)
axes[0].set_xlabel('Iteration', fontsize=12)
axes[0].set_ylabel('KL Divergence', fontsize=12)
axes[0].set_title('t-SNE Optimization Progress', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# Plot 2: Embedded points
scatter = axes[1].scatter(
    Y_embedded[:, 0], 
    Y_embedded[:, 1], 
    c=y_subset, 
    cmap='tab10',
    s=50,
    alpha=0.7,
    edgecolors='black',
    linewidth=0.5
)
axes[1].set_xlabel('t-SNE Dimension 1', fontsize=12)
axes[1].set_ylabel('t-SNE Dimension 2', fontsize=12)
axes[1].set_title('Digits Dataset - t-SNE Embedding', fontsize=14, fontweight='bold')
plt.colorbar(scatter, ax=axes[1], label='Digit Class')

plt.tight_layout()
plt.show()

## Example 2: Synthetic Data - Clustered Structure

Let's create synthetic data with clear cluster structure to see how t-SNE preserves it.

In [None]:
# Create synthetic clustered data
X_synthetic, y_synthetic = make_classification(
    n_samples=400,
    n_features=20,
    n_informative=15,
    n_redundant=5,
    n_classes=4,
    n_clusters_per_class=1,
    class_sep=2.0,
    random_state=42
)

# Standardize
X_synthetic = StandardScaler().fit_transform(X_synthetic)

print(f"Synthetic data shape: {X_synthetic.shape}")
print(f"Number of classes: {len(np.unique(y_synthetic))}")

In [None]:
# Run t-SNE on synthetic data
Y_synthetic, losses_synthetic = tsne(
    X_synthetic, 
    n_components=2, 
    perplexity=30.0, 
    n_iter=1000,
    learning_rate=200.0
)

In [None]:
# Visualize synthetic data results
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Plot 1: Loss curve
axes[0].plot(losses_synthetic, linewidth=2, color='darkred')
axes[0].set_xlabel('Iteration', fontsize=12)
axes[0].set_ylabel('KL Divergence', fontsize=12)
axes[0].set_title('t-SNE Optimization Progress (Synthetic Data)', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# Plot 2: Embedded points
scatter = axes[1].scatter(
    Y_synthetic[:, 0], 
    Y_synthetic[:, 1], 
    c=y_synthetic, 
    cmap='viridis',
    s=60,
    alpha=0.7,
    edgecolors='black',
    linewidth=0.5
)
axes[1].set_xlabel('t-SNE Dimension 1', fontsize=12)
axes[1].set_ylabel('t-SNE Dimension 2', fontsize=12)
axes[1].set_title('Synthetic Data - t-SNE Embedding', fontsize=14, fontweight='bold')
plt.colorbar(scatter, ax=axes[1], label='Class')

plt.tight_layout()
plt.show()

## Understanding the Key Parameters

### 1. Perplexity
- Controls effective number of neighbors considered for each point
- Typical range: 5-50
- Lower perplexity → focuses on very local structure
- Higher perplexity → focuses on more global structure
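The link between perplexity and neighborhood size runs through $\sigma_i$: a higher target perplexity forces a wider Gaussian. The sketch below reruns a perplexity search for one point of hypothetical random data and prints the resulting $\sigma$ for perplexities 5, 30 and 50. It bisects on $\beta = 1/(2\sigma^2)$ in log-space, a variant of the search in `compute_perplexity_sigma`, with an assumed bracket $[10^{-10}, 10^{10}]$:

```python
import numpy as np

def sigma_for_perplexity(sq_dists, target, n_steps=100):
    """Bisection on beta = 1 / (2 sigma^2) until 2**H(P_i) matches `target`.

    sq_dists: squared distances from one point to every OTHER point.
    """
    shifted = sq_dists - sq_dists.min()   # shift cancels in normalization, avoids underflow
    lo, hi = 1e-10, 1e10                  # assumed search bracket for beta
    for _ in range(n_steps):
        beta = np.sqrt(lo * hi)           # bisect in log-space
        p = np.exp(-shifted * beta)
        p /= p.sum()
        nz = p[p > 0]
        perp = 2 ** (-np.sum(nz * np.log2(nz)))
        if perp > target:
            lo = beta                     # too many effective neighbors: narrow the kernel
        else:
            hi = beta                     # too few: widen it
    return np.sqrt(1.0 / (2.0 * beta))

rng = np.random.default_rng(0)
X = rng.normal(size=(201, 10))                  # hypothetical data
sq_dists = np.sum((X[1:] - X[0]) ** 2, axis=1)  # distances from point 0 to the other 200

sigmas = {perp: sigma_for_perplexity(sq_dists, perp) for perp in (5, 30, 50)}
for perp, s in sigmas.items():
    print(f"perplexity {perp:>2}: sigma = {s:.3f}")
```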

### 2. Learning Rate
- Controls step size in gradient descent
- Typical range: 10-1000
- Too low → slow convergence
- Too high → unstable optimization
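The "too low / too high" trade-off is the usual gradient-descent one. As a loose analogy only (the t-SNE cost is not quadratic), plain gradient descent on $f(x) = \tfrac{1}{2} a x^2$ converges when the learning rate is below $2/a$ and diverges above it; the sketch uses an assumed curvature $a = 10$:

```python
def gd_on_quadratic(lr, curvature=10.0, x0=1.0, n_steps=50):
    """Plain gradient descent on f(x) = 0.5 * curvature * x**2 (a stand-in, not t-SNE's cost)."""
    x = x0
    for _ in range(n_steps):
        x -= lr * curvature * x       # the gradient of f is curvature * x
    return x

# Stable only when lr < 2 / curvature = 0.2 here
for lr in (0.001, 0.05, 0.25):
    print(f"lr={lr:<6} final |x| = {abs(gd_on_quadratic(lr)):.3e}")
```

With `lr=0.001` the iterate has barely moved after 50 steps, `lr=0.05` converges quickly, and `lr=0.25` (above $2/a = 0.2$) oscillates with growing amplitude.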

### 3. Early Exaggeration
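- In this implementation, multiplies P by 4 during the first 250 iterations (scikit-learn's `TSNE` uses a default factor of 12)
- Strengthens the attraction between high-affinity pairs, so clusters form tight, well-separated groups early on
- Leaves empty space between clusters, so they can move relative to each other before fine-tuning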
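Early exaggeration only affects the attractive part of the gradient. Splitting $\partial C / \partial y_i$ into an attractive term $4 \sum_j p_{ij} w_{ij} (y_i - y_j)$ and a repulsive term $-4 \sum_j q_{ij} w_{ij} (y_i - y_j)$, with $w_{ij} = (1 + ||y_i - y_j||^2)^{-1}$, shows that multiplying $P$ by $\alpha$ scales the attraction by $\alpha$ and leaves the repulsion unchanged. The sketch uses a random symmetric $P$ as a hypothetical stand-in for real affinities:

```python
import numpy as np

def gradient_terms(P, Y):
    """Split the t-SNE gradient into attractive (P) and repulsive (Q) parts."""
    sq = np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    W = 1.0 / (1.0 + sq)
    np.fill_diagonal(W, 0.0)
    Q = W / W.sum()

    def weighted_sum(M):
        # sum_j M_ij (y_i - y_j), computed without an (n, n, d) tensor
        return M.sum(axis=1)[:, None] * Y - M @ Y

    attractive = 4.0 * weighted_sum(P * W)
    repulsive = -4.0 * weighted_sum(Q * W)
    return attractive, repulsive

rng = np.random.default_rng(1)
n = 20
A = rng.random((n, n))
P = A + A.T
np.fill_diagonal(P, 0.0)
P /= P.sum()                              # hypothetical symmetric affinities
Y = rng.normal(scale=1e-2, size=(n, 2))   # small initial map

for alpha in (1.0, 4.0):
    attr, rep = gradient_terms(alpha * P, Y)
    print(f"alpha={alpha}: |attractive| = {np.linalg.norm(attr):.3e}, "
          f"|repulsive| = {np.linalg.norm(rep):.3e}")
```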

### 4. Momentum
- Speeds up progress along directions where successive gradients agree, and can carry the optimizer past small local minima
- Smooths the optimization trajectory
- Typical value: 0.5-0.9
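The benefit of momentum shows up clearly on an ill-conditioned problem. This sketch applies the same heavy-ball update used in `tsne()` (`velocity = momentum * velocity - lr * gradient`) to a toy quadratic whose curvatures differ by a factor of 100 (an assumption for illustration, not the t-SNE cost) and counts the steps needed to get within $10^{-6}$ of the minimum:

```python
import numpy as np

def steps_to_converge(momentum, lr=0.015, tol=1e-6, max_steps=10_000):
    """Heavy-ball updates on f(y) = 0.5 * (y0**2 + 100 * y1**2); returns steps until ||y|| < tol."""
    curvature = np.array([1.0, 100.0])   # an ill-conditioned toy cost, not t-SNE's
    y = np.array([1.0, 1.0])
    velocity = np.zeros_like(y)
    for step in range(1, max_steps + 1):
        gradient = curvature * y
        velocity = momentum * velocity - lr * gradient
        y = y + velocity
        if np.linalg.norm(y) < tol:
            return step
    return max_steps

for m in (0.0, 0.5, 0.8):
    print(f"momentum={m}: {steps_to_converge(m)} steps")
```

Here the learning rate is capped by the steep direction, so plain gradient descent (`momentum=0.0`) needs several hundred steps to crawl along the shallow one, while `momentum=0.8` reaches the tolerance in far fewer.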

## Comparison with Different Perplexity Values

In [None]:
# Compare different perplexity values
perplexities = [5, 30, 50]
embeddings = []

# Use smaller subset for faster comparison
n_compare = 300
X_compare = X_digits[:n_compare]
y_compare = y_digits[:n_compare]

for perp in perplexities:
    print(f"\n{'='*50}")
    print(f"Running t-SNE with perplexity = {perp}")
    print(f"{'='*50}")
    Y_perp, _ = tsne(
        X_compare, 
        n_components=2, 
        perplexity=perp, 
        n_iter=500,  # Fewer iterations for comparison
        learning_rate=200.0
    )
    embeddings.append(Y_perp)

In [None]:
# Visualize different perplexity results
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, (perp, Y_emb) in enumerate(zip(perplexities, embeddings)):
    scatter = axes[idx].scatter(
        Y_emb[:, 0], 
        Y_emb[:, 1], 
        c=y_compare, 
        cmap='tab10',
        s=40,
        alpha=0.7,
        edgecolors='black',
        linewidth=0.5
    )
    axes[idx].set_xlabel('t-SNE Dimension 1', fontsize=11)
    axes[idx].set_ylabel('t-SNE Dimension 2', fontsize=11)
    axes[idx].set_title(f'Perplexity = {perp}', fontsize=13, fontweight='bold')
    if idx == 2:
        plt.colorbar(scatter, ax=axes[idx], label='Digit')

plt.suptitle('Effect of Perplexity on t-SNE Embeddings', fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Key Takeaways

1. **t-SNE is stochastic** - different runs produce different results (but similar structure)

2. **Cluster sizes and distances between clusters in t-SNE plots are not reliably meaningful** - what the embedding is designed to preserve is which points are neighbors of which

3. **t-SNE preserves local structure** - points close in high-D stay close in low-D

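4. **Computational complexity is O(N²) per iteration** in both time and memory - P, Q and the gradient are all computed from N×N matrices, so this exact version doesn't scale well to very large datasets
   - For larger datasets, use approximations like Barnes-Hut t-SNE, which runs in O(N log N) per iteration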

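5. **Choose perplexity based on dataset size** (rough heuristics; perplexity must stay well below N, since a point cannot have more effective neighbors than there are other points):
   - Small datasets (N < 500): perplexity = 5-15
   - Medium datasets (500 ≤ N ≤ 5000): perplexity = 30-50
   - Larger datasets: consider higher values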

6. **Early exaggeration helps** - it creates tighter initial clusters that can separate better

7. **Run for enough iterations** - typically 1000+ to ensure convergence
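Takeaway 3 can be measured rather than eyeballed. A simple score (a sketch; scikit-learn's `sklearn.manifold.trustworthiness` is a more established metric of the same kind) is the average fraction of each point's $k$ nearest neighbors in the original space that are also among its $k$ nearest neighbors in the embedding. It is 1.0 when neighborhoods are perfectly preserved and close to $k/(N-1)$ for an unrelated random layout. The function name `knn_preservation` is hypothetical:

```python
import numpy as np

def knn_preservation(X, Y, k=10):
    """Average fraction of each point's k nearest neighbors in X that are also
    among its k nearest neighbors in the embedding Y (1.0 = perfectly preserved)."""
    def knn_indices(Z):
        sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
        np.fill_diagonal(sq, np.inf)          # a point is not its own neighbor
        return np.argsort(sq, axis=1)[:, :k]

    nn_high, nn_low = knn_indices(X), knn_indices(Y)
    overlaps = [len(set(a) & set(b)) / k for a, b in zip(nn_high, nn_low)]
    return float(np.mean(overlaps))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
print("identity map:  ", knn_preservation(X, X))
print("random 2-D map:", knn_preservation(X, rng.normal(size=(200, 2))))
```

Applied to the digits example, `knn_preservation(X_subset, Y_embedded)` gives a quantitative counterpart to the scatter plot.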