# Self-Organizing Map (SOM) Visualization

This notebook focuses on Self-Organizing Maps for topographical visualization and analysis.

## Overview
- Train and visualize Self-Organizing Maps
- Explore U-matrix, component planes, and hit maps
- Analyze cluster structure in SOM grids
- Track topographic and quantization errors

In [None]:
# Setup and imports
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Configure plotting
plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

# Import PoT modules
from pot.semantic.topography import SOMProjector
from pot.semantic.topography_utils import (
    prepare_latents_for_projection,
    identify_clusters_in_projection
)
from pot.semantic.topography_visualizer import (
    plot_som_grid,
    plot_som_u_matrix,
    plot_som_component_planes,
    plot_som_hit_map,
    plot_som_clusters
)

print("Setup complete!")

## 1. Generate Sample Data

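Create 10-dimensional data for SOM training: three Gaussian clusters plus one noisy sinusoidal curve, so the map has both blob-like and curve-like structure to organize.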

In [None]:
def generate_structured_data(n_samples=1000):
    """Generate data with multiple structural patterns."""
    np.random.seed(42)
    
    data_parts = []
    labels = []
    
    # Gaussian clusters
    for i in range(3):
        center = np.random.randn(10) * 3
        cluster = np.random.randn(n_samples // 4, 10) + center
        data_parts.append(cluster)
        labels.extend([i] * len(cluster))
    
    # Curved 1-D manifold: sinusoidal curve parameterized by t, padded with low-variance noise dims
    t = np.linspace(0, 4*np.pi, n_samples // 4)
    linear = np.column_stack([
        np.sin(t), np.cos(t), t/10,
        np.sin(2*t), np.cos(2*t),
        np.random.randn(len(t), 5) * 0.1
    ])
    data_parts.append(linear)
    labels.extend([3] * len(linear))
    
    data = np.vstack(data_parts)
    labels = np.array(labels)
    
    # Shuffle
    indices = np.random.permutation(len(data))
    data = data[indices]
    labels = labels[indices]
    
    return data, labels

# Generate data
data, labels = generate_structured_data(800)
print(f"Data shape: {data.shape}")
print(f"Unique labels: {np.unique(labels)}")

## 2. Train Self-Organizing Map
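
For intuition about what training does, here is a minimal NumPy sketch of the classic online SOM update. It is not the `SOMProjector` implementation: `train_som_sketch` and its parameters are hypothetical names, the grid is rectangular rather than hexagonal, and the linear decay of the learning rate and neighborhood radius is an assumption.

```python
import numpy as np

def train_som_sketch(data, grid_shape=(5, 5), n_iter=500,
                     learning_rate=0.5, sigma=1.5, seed=0):
    """Online SOM training on a rectangular grid (illustrative only)."""
    rng = np.random.default_rng(seed)
    rows, cols = grid_shape
    # Initialize each neuron's weight vector from a randomly chosen sample
    weights = data[rng.integers(0, len(data), rows * cols)].astype(float)
    weights = weights.reshape(rows, cols, -1)
    # Grid coordinates of every neuron, shape (rows, cols, 2)
    coords = np.stack(
        np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij"), axis=-1
    )

    for t in range(n_iter):
        # Assumed schedule: linear decay of learning rate and radius
        frac = 1.0 - t / n_iter
        lr_t = learning_rate * frac
        sigma_t = max(sigma * frac, 1e-2)

        x = data[rng.integers(len(data))]
        # Best matching unit (BMU): neuron whose weights are closest to x
        dists = np.linalg.norm(weights - x, axis=-1)
        bmu = np.unravel_index(np.argmin(dists), dists.shape)
        # Gaussian neighborhood around the BMU, measured on the grid
        grid_d2 = np.sum((coords - np.array(bmu)) ** 2, axis=-1)
        h = np.exp(-grid_d2 / (2.0 * sigma_t ** 2))
        # Pull every neuron toward x, scaled by its neighborhood value
        weights += lr_t * h[..., None] * (x - weights)
    return weights
```

Because each step moves weights toward a sample by a factor of at most `learning_rate`, the weights stay inside the bounding box of the data when `learning_rate <= 1`.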

In [None]:
# Create and train SOM
som_projector = SOMProjector(
    grid_size=(20, 20),
    learning_rate=0.5,
    sigma=2.0,
    topology='hexagonal'
)

print("Training SOM...")
som_projector.train(data, num_iteration=200)
print("Training complete!")

# Get projections
projected = som_projector.project(data)
print(f"Projected shape: {projected.shape}")

# Compute errors
qe = som_projector.quantization_error(data)
te = som_projector.topographic_error(data)
print(f"\nQuantization Error: {qe:.4f}")
print(f"Topographic Error: {te:.4f}")

## 3. Basic SOM Visualization

In [None]:
# Plot projected data
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Scatter plot of projections
ax = axes[0]
scatter = ax.scatter(projected[:, 0], projected[:, 1], 
                     c=labels, cmap='tab10', alpha=0.6, s=20)
ax.set_title("SOM Projection")
ax.set_xlabel("SOM X")
ax.set_ylabel("SOM Y")
ax.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax, label='Cluster')

# Hexbin plot for density
ax = axes[1]
hexbin = ax.hexbin(projected[:, 0], projected[:, 1], 
                   gridsize=15, cmap='YlOrRd')
ax.set_title("Density Map")
ax.set_xlabel("SOM X")
ax.set_ylabel("SOM Y")
plt.colorbar(hexbin, ax=ax, label='Count')

plt.tight_layout()
plt.show()

## 4. U-Matrix Visualization

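The U-Matrix (unified distance matrix) shows, for each neuron, the average distance between its weight vector and those of its grid neighbors. High values mark cluster boundaries; low values mark cluster interiors.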
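
As a rough illustration of the computation, the sketch below derives a U-Matrix from a weight array of shape `(rows, cols, n_features)`. The function name `u_matrix_sketch` is hypothetical, and it assumes a rectangular grid with 4-connected neighbors; a hexagonal grid, as used in this notebook, has six neighbors per neuron, so `plot_som_u_matrix` may compute something different.

```python
import numpy as np

def u_matrix_sketch(weights):
    """Mean distance from each neuron to its 4-connected grid neighbors."""
    rows, cols, _ = weights.shape
    total = np.zeros((rows, cols))
    count = np.zeros((rows, cols))

    # Distances between vertically adjacent neurons, shape (rows - 1, cols)
    d_vert = np.linalg.norm(weights[1:] - weights[:-1], axis=-1)
    total[1:] += d_vert
    total[:-1] += d_vert
    count[1:] += 1
    count[:-1] += 1

    # Distances between horizontally adjacent neurons, shape (rows, cols - 1)
    d_horz = np.linalg.norm(weights[:, 1:] - weights[:, :-1], axis=-1)
    total[:, 1:] += d_horz
    total[:, :-1] += d_horz
    count[:, 1:] += 1
    count[:, :-1] += 1

    # Average over however many neighbors each neuron actually has
    return total / np.maximum(count, 1)
```

On a toy grid whose left half has weights near 0 and whose right half has weights near 10, only the two columns touching the split get nonzero values, which is the "ridge" that separates clusters in a U-Matrix plot.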

In [None]:
if hasattr(som_projector, 'som') and som_projector.som is not None:
    # Plot U-Matrix
    fig = plot_som_u_matrix(
        som_projector.som,
        title="SOM U-Matrix (Distance Map)"
    )
    plt.show()
    
    print("High U-Matrix values indicate cluster boundaries (large distances between neurons)")
    print("Low U-Matrix values indicate cluster interiors (small distances between neurons)")
else:
    print("SOM not available for detailed visualization")

## 5. Component Planes

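Each component plane shows the value of one input feature in every neuron's weight vector, laid out on the SOM grid. Planes with similar patterns point to correlated features.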
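
A component plane is just one slice of the weight array, which the following sketch makes explicit. It assumes the weights are available as a NumPy array of shape `(rows, cols, n_features)`; `component_plane` and `component_plane_correlation` are hypothetical helper names, not part of the PoT modules.

```python
import numpy as np

def component_plane(weights, feature_index):
    """Return the (rows, cols) plane of one feature across the grid."""
    return weights[:, :, feature_index]

def component_plane_correlation(weights):
    """Pearson correlation between all pairs of flattened component planes."""
    flat = weights.reshape(-1, weights.shape[-1])  # one row per neuron
    return np.corrcoef(flat, rowvar=False)

# Toy weights: feature 0 grows along rows, feature 1 mirrors it,
# feature 2 grows along columns.
r, c = np.meshgrid(np.arange(3), np.arange(4), indexing="ij")
toy_weights = np.stack([r, -r, c], axis=-1).astype(float)
corr = component_plane_correlation(toy_weights)
```

Here `corr[0, 1]` is strongly negative because the two planes are mirror images, while `corr[0, 2]` is near zero because the features vary along different grid axes.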

In [None]:
if hasattr(som_projector, 'som') and som_projector.som is not None:
    # Select a few components to visualize
    n_components_to_show = min(6, data.shape[1])
    
    fig = plot_som_component_planes(
        som_projector.som,
        component_indices=list(range(n_components_to_show)),
        component_names=[f"Feature {i+1}" for i in range(n_components_to_show)],
        title="SOM Component Planes"
    )
    plt.show()
    
    print("Each subplot shows how one feature varies across the SOM grid")
else:
    # Manual component plane visualization
    print("Creating manual component visualization...")
    
    # Group data by grid position
    grid_size = 20
    grid_values = np.zeros((grid_size, grid_size, data.shape[1]))
    grid_counts = np.zeros((grid_size, grid_size))
    
    for i, (x, y) in enumerate(projected):
        grid_x = int(x)
        grid_y = int(y)
        if 0 <= grid_x < grid_size and 0 <= grid_y < grid_size:
            grid_values[grid_y, grid_x] += data[i]
            grid_counts[grid_y, grid_x] += 1
    
    # Average values
    mask = grid_counts > 0
    grid_values[mask] /= grid_counts[mask, np.newaxis]
    
    # Plot first few components
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes = axes.flatten()
    
    for i in range(min(6, data.shape[1])):
        im = axes[i].imshow(grid_values[:, :, i], cmap='coolwarm', aspect='auto')
        axes[i].set_title(f"Feature {i+1}")
        plt.colorbar(im, ax=axes[i])
    
    plt.tight_layout()
    plt.show()

## 6. Hit Map and Cluster Analysis

In [None]:
# Create hit map manually
grid_size = 20
hit_map = np.zeros((grid_size, grid_size))

for x, y in projected:
    grid_x = int(x)
    grid_y = int(y)
    if 0 <= grid_x < grid_size and 0 <= grid_y < grid_size:
        hit_map[grid_y, grid_x] += 1

# Plot hit map
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Hit map
ax = axes[0]
im = ax.imshow(hit_map, cmap='hot', aspect='auto')
ax.set_title("SOM Hit Map")
ax.set_xlabel("SOM X")
ax.set_ylabel("SOM Y")
plt.colorbar(im, ax=ax, label='Number of Hits')

# Cluster assignments on grid (when several samples share a cell, the last one's label is shown)
ax = axes[1]
cluster_grid = np.full((grid_size, grid_size), -1, dtype=float)

for i, (x, y) in enumerate(projected):
    grid_x = int(x)
    grid_y = int(y)
    if 0 <= grid_x < grid_size and 0 <= grid_y < grid_size:
        cluster_grid[grid_y, grid_x] = labels[i]

# Mask empty cells
masked_grid = np.ma.masked_where(cluster_grid < 0, cluster_grid)

im = ax.imshow(masked_grid, cmap='tab10', aspect='auto')
ax.set_title("Cluster Distribution on SOM")
ax.set_xlabel("SOM X")
ax.set_ylabel("SOM Y")
plt.colorbar(im, ax=ax, label='Cluster ID')

plt.tight_layout()
plt.show()

print(f"Total hits: {int(np.sum(hit_map))}")
print(f"Max hits in single cell: {int(np.max(hit_map))}")
print(f"Cells with hits: {np.sum(hit_map > 0)} / {grid_size * grid_size}")

## 7. Quality Analysis Across Hyperparameters

In [None]:
# Train multiple SOMs with different parameters
param_variations = [
    {'learning_rate': 0.1, 'sigma': 1.0},
    {'learning_rate': 0.5, 'sigma': 2.0},
    {'learning_rate': 1.0, 'sigma': 3.0},
]

results = []

for params in param_variations:
    som = SOMProjector(grid_size=(15, 15), **params)
    som.train(data[:500], num_iteration=100)  # Use subset for speed
    
    qe = som.quantization_error(data[:500])
    te = som.topographic_error(data[:500])
    
    results.append({
        'params': params,
        'quantization_error': qe,
        'topographic_error': te
    })

# Plot results
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Quantization error
ax = axes[0]
qe_values = [r['quantization_error'] for r in results]
labels_param = [f"lr={r['params']['learning_rate']}, σ={r['params']['sigma']}" 
                for r in results]
bars = ax.bar(range(len(results)), qe_values, color='steelblue')
ax.set_xticks(range(len(results)))
ax.set_xticklabels(labels_param, rotation=45)
ax.set_ylabel("Quantization Error")
ax.set_title("Quantization Error vs Parameters")
ax.grid(True, alpha=0.3)

# Topographic error
ax = axes[1]
te_values = [r['topographic_error'] for r in results]
bars = ax.bar(range(len(results)), te_values, color='coral')
ax.set_xticks(range(len(results)))
ax.set_xticklabels(labels_param, rotation=45)
ax.set_ylabel("Topographic Error")
ax.set_title("Topographic Error vs Parameters")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print best parameters
best_qe_idx = np.argmin(qe_values)
best_te_idx = np.argmin(te_values)

print(f"Parameters with lowest quantization error: {results[best_qe_idx]['params']}")
print(f"Parameters with lowest topographic error: {results[best_te_idx]['params']}")

## 8. Compare SOM with Other Projections

In [None]:
from pot.semantic import TopographicalProjector
from pot.semantic.topography_utils import (
    compute_trustworthiness,
    compute_continuity
)

# Compare SOM with PCA and UMAP
methods = ['som', 'pca', 'umap']
projections_compare = {}
metrics_compare = {}

subset_data = data[:400]  # Same subset for every method, so the metrics are comparable
subset_labels = labels[:400]

for method in methods:
    print(f"\nProjecting with {method.upper()}...")
    
    if method == 'som':
        # Use our already trained SOM
        proj = som_projector.project(subset_data)
    else:
        projector = TopographicalProjector(method)
        proj = projector.project_latents(torch.tensor(subset_data, dtype=torch.float32))
    
    projections_compare[method] = proj
    
    # Compute metrics
    trust = compute_trustworthiness(subset_data, proj, n_neighbors=10)
    cont = compute_continuity(subset_data, proj, n_neighbors=10)
    
    metrics_compare[method] = {'trust': trust, 'cont': cont}
    print(f"  Trustworthiness: {trust:.3f}")
    print(f"  Continuity: {cont:.3f}")

# Visualize comparisons
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, (method, proj) in enumerate(projections_compare.items()):
    ax = axes[idx]
    scatter = ax.scatter(proj[:, 0], proj[:, 1], 
                        c=subset_labels, cmap='tab10', 
                        alpha=0.6, s=20)
    
    metrics = metrics_compare[method]
    ax.set_title(f"{method.upper()}\nT={metrics['trust']:.2f}, C={metrics['cont']:.2f}")
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. SOM for Anomaly Detection

In [None]:
# Use SOM for anomaly detection
# Points far from their BMU (Best Matching Unit) are potential anomalies

# Get BMU distances
bmu_distances = []

if hasattr(som_projector, 'som') and som_projector.som is not None:
    for sample in data:
        # Find BMU
        bmu = som_projector.som.winner(sample)
        # Get weight vector of BMU
        bmu_weight = som_projector.som.get_weights()[bmu]
        # Compute distance
        dist = np.linalg.norm(sample - bmu_weight)
        bmu_distances.append(dist)
else:
    # Approximate using projected positions
    for i, sample in enumerate(data):
        # Find nearest grid point
        grid_x, grid_y = int(projected[i, 0]), int(projected[i, 1])
        
        # Find other points at same grid position
        same_cell = []
        for j, (x, y) in enumerate(projected):
            if int(x) == grid_x and int(y) == grid_y and i != j:
                same_cell.append(data[j])
        
        if same_cell:
            # Distance to mean of cell
            cell_mean = np.mean(same_cell, axis=0)
            dist = np.linalg.norm(sample - cell_mean)
        else:
            dist = 0
        
        bmu_distances.append(dist)

bmu_distances = np.array(bmu_distances)

# Identify anomalies (top 5% distances)
threshold = np.percentile(bmu_distances, 95)
anomalies = bmu_distances > threshold

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Distance distribution
ax = axes[0]
ax.hist(bmu_distances, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
ax.axvline(threshold, color='red', linestyle='--', label=f'Threshold (95%): {threshold:.2f}')
ax.set_xlabel("Distance to BMU")
ax.set_ylabel("Count")
ax.set_title("BMU Distance Distribution")
ax.legend()
ax.grid(True, alpha=0.3)

# Anomalies on projection
ax = axes[1]
ax.scatter(projected[~anomalies, 0], projected[~anomalies, 1], 
          c='steelblue', alpha=0.5, s=20, label='Normal')
ax.scatter(projected[anomalies, 0], projected[anomalies, 1], 
          c='red', alpha=0.8, s=50, marker='^', label='Anomaly')
ax.set_title("Anomalies in SOM Projection")
ax.set_xlabel("SOM X")
ax.set_ylabel("SOM Y")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Number of anomalies detected: {np.sum(anomalies)} / {len(data)}")
print(f"Anomaly rate: {np.mean(anomalies)*100:.1f}%")

## Summary

This notebook explored Self-Organizing Maps (SOMs) for topographical visualization:

### Key Concepts:
1. **SOM Training**: Unsupervised learning to create a 2D grid representation
2. **U-Matrix**: Visualizes distances between neurons to reveal cluster boundaries
3. **Component Planes**: Shows how individual features vary across the map
4. **Hit Map**: Indicates data density across the grid
5. **Quality Metrics**:
   - **Quantization Error**: Average distance between each sample and the weight vector of its BMU
   - **Topographic Error**: Proportion of samples whose best and second-best matching units are not adjacent on the grid
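
The two quality metrics above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the `SOMProjector` implementation: `som_errors_sketch` is a hypothetical name, the grid is rectangular, and "adjacent" is taken to mean within one step in each grid direction (8-neighborhood).

```python
import numpy as np

def som_errors_sketch(data, weights):
    """Return (quantization_error, topographic_error) for a rectangular SOM."""
    rows, cols, dim = weights.shape
    flat = weights.reshape(-1, dim)
    # Distance from every sample to every neuron, shape (n_samples, rows * cols)
    dists = np.linalg.norm(data[:, None, :] - flat[None, :, :], axis=-1)
    order = np.argsort(dists, axis=1)
    best, second = order[:, 0], order[:, 1]

    # Quantization error: mean distance from each sample to its BMU
    qe = dists[np.arange(len(data)), best].mean()

    # Grid coordinates of the best and second-best units
    best_r, best_c = np.divmod(best, cols)
    sec_r, sec_c = np.divmod(second, cols)
    # Assumed adjacency rule: at most one step apart in each direction
    adjacent = (np.abs(best_r - sec_r) <= 1) & (np.abs(best_c - sec_c) <= 1)

    # Topographic error: share of samples whose two best units are not adjacent
    te = 1.0 - adjacent.mean()
    return qe, te
```

A low quantization error with a high topographic error suggests the map fits the data well but is folded or twisted, so grid neighbors do not correspond to neighbors in feature space.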

### Applications:
- **Clustering**: Natural grouping emerges on the 2D grid
- **Visualization**: Intuitive 2D representation of high-dimensional data
- **Anomaly Detection**: Points far from their BMUs are candidate outliers
- **Feature Analysis**: Component planes reveal feature relationships

### Advantages of SOMs:
- Topology preservation
- Interpretable grid structure
- Can handle large datasets
- Clusters emerge without specifying k (the grid size still has to be chosen)

### When to Use SOMs:
- Exploratory data analysis
- When topology preservation is important
- For creating interpretable 2D maps
- When you need both clustering and visualization