# Global Sequential Optimal Transport with UMAP + Gromov-Wasserstein

This notebook demonstrates:
- **UMAP** for non-linear dimensionality reduction (384D → 14D)
- **Gromov-Wasserstein** for manifold-aware sequential alignment
- **Comparison** with Procrustes baseline
- **Visualizations** of alignment quality

## Setup (Run this first in Google Colab)

In [None]:
# Colab-only setup: uncomment the four lines below and run this cell first.
# They clone the repository, change into its directory, and install the
# packages listed in requirements.txt plus umap-learn (imported below as `umap`).
# Leave everything commented out when running from a local checkout.
# !git clone https://github.com/tomasblood/SeqOTTest.git
# %cd SeqOTTest
# !pip install -q -r requirements.txt
# !pip install -q umap-learn

## 1. Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import time
import pickle
import umap

from src.seqot.alignment import GlobalSeqOTAlignment
from src.seqot.data_loaders import create_sample_neurips_data
from src.seqot.evaluation import evaluate_alignment

# Notebook-wide plotting defaults: seaborn grid style and the default figure size
# (individual figures below still pass their own figsize where they need one)
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

print("✓ Libraries loaded successfully")

## 2. Create Realistic NeurIPS-scale Embeddings

In [None]:
# Create embeddings: 100 papers/year × 6 years × 384 dimensions.
# The helper's return value is used below as the path of the pickle to open.
data_path = create_sample_neurips_data(
    output_path='data/neurips/umap_gw_notebook.pkl',
    n_years=6,
    n_papers_per_year=100,
    n_dims=384,
    random_state=42
)

# Load embeddings.
# NOTE: pickle.load can execute arbitrary code stored in the file. That is
# acceptable here only because the file was produced by the call above; do not
# point data_path at a pickle from an untrusted source.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

embeddings = data['embeddings']
years = data['years']

# Later cells pair years[i] with embeddings[i]; zip() would silently truncate
# on a mismatch, so fail loudly here instead.
assert len(embeddings) == len(years), (
    f"expected one embedding matrix per year, got "
    f"{len(embeddings)} matrices for {len(years)} years"
)

print(f"\n✓ Loaded {len(years)} time steps: {years[0]}-{years[-1]}")
for year, emb in zip(years, embeddings):
    print(f"  {year}: {emb.shape[0]} papers × {emb.shape[1]} dims")

## 3. Apply UMAP Dimensionality Reduction (384D → 14D)

UMAP preserves **non-linear manifold structure**, unlike PCA's linear projection.

In [None]:
# Dimensionality of the reduced space that the alignment cells below work in
n_components = 14

print(f"Applying UMAP: {embeddings[0].shape[1]}D → {n_components}D")
print("This preserves local manifold structure (non-linear)\n")

# Stack all years into one matrix so a single reducer is fitted on every paper
all_embeddings = np.vstack(embeddings)
print(f"Total samples: {all_embeddings.shape[0]} ({len(embeddings)} years × {embeddings[0].shape[0]} papers)")

# The timer covers reducer construction and fit only; the per-year transforms
# further down are deliberately outside the measured span.
fit_started = time.time()
umap_params = dict(
    n_components=n_components,
    n_neighbors=30,      # larger neighbourhoods, chosen to favour global structure
    min_dist=0.0,        # allow embedded points to pack tightly
    metric='euclidean',
    random_state=42,
    verbose=False,
)
reducer = umap.UMAP(**umap_params)
reducer.fit(all_embeddings)
t_umap_fit = time.time() - fit_started

# Push each time step through the fitted reducer, keeping the per-year grouping
embeddings_reduced = [reducer.transform(year_matrix) for year_matrix in embeddings]

print(f"\n✓ UMAP fitted in {t_umap_fit:.2f}s")
print(f"  Original: {embeddings[0].shape[1]}D")
print(f"  Reduced:  {n_components}D")
print(f"  Method: Non-linear manifold preservation")

## 4. Run Global SeqOT with Gromov-Wasserstein on 14D UMAP Space

Gromov-Wasserstein (GW) compares the **internal geometries** of two point sets by matching their pairwise distance matrices.

In [None]:
print("Running GW on UMAP-reduced embeddings (14D)...\n")

# Aligner settings gathered in one mapping so they are easy to compare with the
# baseline cell. use_gromov=True is what distinguishes this run from the baseline.
gw_settings = {
    'epsilon': 0.5,
    'max_iter': 20,
    'tol': 1e-5,
    'use_gromov': True,
    'metric': 'euclidean',
    'verbose': True,
}
aligner_gw = GlobalSeqOTAlignment(**gw_settings)

# Time fit + transform together on the reduced embeddings
gw_started = time.time()
aligned_gw_reduced = aligner_gw.fit_transform(embeddings_reduced)
t_gw = time.time() - gw_started

print(f"\n✓ GW completed in {t_gw:.1f}s on {n_components}D embeddings")

## 5. Evaluate GW Alignment Quality

In [None]:
# Score the GW output against the un-aligned embeddings, both in UMAP space
results_gw = evaluate_alignment(
    embeddings_reduced, aligned_gw_reduced, method_name='GW + UMAP'
)

# Emit the report as one call; sep="\n" puts each entry on its own line
print(
    "GW Alignment Metrics (in 14D UMAP space):",
    f"  Euclidean Error: {results_gw['mean_euclidean_error']:.4f}",
    f"  Cosine Distance: {results_gw['mean_cosine_distance']:.4f}",
    f"  Correlation:     {results_gw['mean_correlation']:.4f}",
    sep="\n",
)

## 6. Compare with Procrustes Baseline

In [None]:
print("Running Procrustes baseline on original embeddings...")

# Baseline aligner, run on the original high-dimensional embeddings.
# NOTE(review): this is the same GlobalSeqOTAlignment class with use_gromov=False;
# the notebook labels it "Procrustes" — confirm against src.seqot.alignment that
# this configuration really is the intended Procrustes baseline.
proc_settings = dict(
    epsilon=0.5,
    max_iter=50,
    tol=1e-6,
    use_gromov=False,  # Standard OT with Euclidean cost
    metric='euclidean',
    verbose=False,
)
aligner_proc = GlobalSeqOTAlignment(**proc_settings)

proc_started = time.time()
aligned_proc = aligner_proc.fit_transform(embeddings)
t_procrustes = time.time() - proc_started

print(f"✓ Procrustes completed in {t_procrustes:.2f}s")

# Map every aligned year through the already-fitted reducer so the baseline is
# scored in the same UMAP space as the GW result
aligned_proc_umap = []
for aligned_year in aligned_proc:
    aligned_proc_umap.append(reducer.transform(aligned_year))

results_proc = evaluate_alignment(
    embeddings_reduced,
    aligned_proc_umap,
    method_name='Procrustes (in UMAP space)',
)

print(
    "\nProcrustes Metrics (projected to 14D UMAP space):",
    f"  Euclidean Error: {results_proc['mean_euclidean_error']:.4f}",
    f"  Cosine Distance: {results_proc['mean_cosine_distance']:.4f}",
    f"  Correlation:     {results_proc['mean_correlation']:.4f}",
    sep="\n",
)

## 7. Visualize Results: Metrics Comparison

In [None]:
# Collect the headline metrics for both methods (GW first, baseline second)
methods = ['GW + UMAP', 'Procrustes\n(UMAP space)']
euclidean_errors = [
    results_gw['mean_euclidean_error'],
    results_proc['mean_euclidean_error']
]
cosine_distances = [
    results_gw['mean_cosine_distance'],
    results_proc['mean_cosine_distance']
]
correlations = [
    results_gw['mean_correlation'],
    results_proc['mean_correlation']
]

# One panel per metric: (values, y-axis label, panel title).
# Driving the figure from this table replaces three copy-pasted plotting stanzas.
panels = [
    (euclidean_errors, 'Euclidean Error', 'Euclidean Error (Lower is Better)'),
    (cosine_distances, 'Cosine Distance', 'Cosine Distance (Lower is Better)'),
    (correlations, 'Correlation', 'Correlation (Higher is Better)'),
]
bar_colors = ['#2ecc71', '#3498db']

fig, axes = plt.subplots(1, len(panels), figsize=(18, 5))

for ax, (values, ylabel, title) in zip(axes, panels):
    bars = ax.bar(methods, values, color=bar_colors, alpha=0.7)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    for bar, val in zip(bars, values):
        height = bar.get_height()
        # Keep the label outside the bar: above a positive bar, below a negative
        # one (correlation can be negative, and 'bottom' alone would draw the
        # text inside the bar in that case).
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.3f}', ha='center',
                va='bottom' if height >= 0 else 'top',
                fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

# Relative improvement of GW over the baseline, in percent of the baseline error
gw_error = results_gw['mean_euclidean_error']
proc_error = results_proc['mean_euclidean_error']
if proc_error == 0:
    # A zero baseline error makes the relative improvement undefined
    improvement = float('nan')
    print("\nGW vs Procrustes: n/a (Procrustes error is zero)")
else:
    improvement = (proc_error - gw_error) / proc_error * 100
    if improvement > 0:
        verdict = 'better'
    elif improvement < 0:
        verdict = 'worse'
    else:
        verdict = 'no change'
    print(f"\nGW vs Procrustes: {improvement:+.1f}% ({verdict})")

## 8. Visualize 2D UMAP Projection

Project the 14D UMAP embeddings down to 2D with a second UMAP, for visualization only.

In [None]:
# Second-stage UMAP: 14D reduced space → 2D, used purely for plotting
print("Creating 2D UMAP visualization from 14D space...")

all_reduced_14d = np.vstack(embeddings_reduced)
reducer_2d = umap.UMAP(n_components=2, random_state=42, min_dist=0.0, verbose=False)
all_reduced_2d = reducer_2d.fit_transform(all_reduced_14d)

# The GW-aligned points go through the same fitted 2D reducer
all_aligned_14d = np.vstack(aligned_gw_reduced)
all_aligned_2d = reducer_2d.transform(all_aligned_14d)

# Row boundaries of each time step inside the stacked arrays, derived from the
# per-step sample counts of the un-aligned reduced embeddings
bounds = [0]
for step in embeddings_reduced:
    bounds.append(bounds[-1] + step.shape[0])
spans = list(zip(bounds[:-1], bounds[1:]))

embeddings_2d = [all_reduced_2d[lo:hi] for lo, hi in spans]
aligned_gw_2d = [all_aligned_2d[lo:hi] for lo, hi in spans]

# Two side-by-side panels sharing one colour per year
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
colors = plt.cm.viridis(np.linspace(0, 1, len(years)))

panel_specs = (
    (axes[0], embeddings_2d, 'Original Embeddings\n(2D projection of 14D UMAP)'),
    (axes[1], aligned_gw_2d, 'GW-Aligned Embeddings\n(2D projection of 14D UMAP)'),
)
for ax, steps_2d, panel_title in panel_specs:
    for year_color, year, emb_2d in zip(colors, years, steps_2d):
        ax.scatter(emb_2d[:, 0], emb_2d[:, 1], alpha=0.6, label=str(year),
                   s=50, color=year_color, edgecolors='white', linewidth=0.5)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.set_xlabel('UMAP Component 1', fontsize=12)
    ax.set_ylabel('UMAP Component 2', fontsize=12)
    ax.legend(title='Year', bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ 2D visualization complete")

## 9. Visualize Transport Couplings

Show the optimal transport coupling matrices between consecutive years.

In [None]:
# Look the couplings up defensively: the aligner may not expose `solver_`, and
# the solver may not expose `couplings_`. Either case falls through to the
# "not available" message instead of raising AttributeError.
solver = getattr(aligner_gw, 'solver_', None)
couplings = getattr(solver, 'couplings_', None)

# An empty collection is treated like a missing one; plt.subplots(1, 0) would raise.
if couplings is not None and len(couplings) > 0:
    n_couplings = len(couplings)

    # Plot first 3 coupling matrices
    n_plot = min(3, n_couplings)
    # squeeze=False always returns a 2D axes array, so a single panel needs no
    # special-casing.
    fig, axes = plt.subplots(1, n_plot, figsize=(6*n_plot, 5), squeeze=False)

    for i, ax in enumerate(axes[0]):
        coupling = couplings[i]
        im = ax.imshow(coupling, cmap='viridis', aspect='auto', interpolation='nearest')
        # The i-th coupling is labelled as linking years[i] (rows) to years[i+1] (columns)
        ax.set_title(f'Transport Coupling\n{years[i]} → {years[i+1]}',
                     fontsize=12, fontweight='bold')
        ax.set_xlabel(f'Papers in {years[i+1]}', fontsize=10)
        ax.set_ylabel(f'Papers in {years[i]}', fontsize=10)
        fig.colorbar(im, ax=ax, label='Transport Mass')

    plt.tight_layout()
    plt.show()

    print(f"✓ Visualized {n_plot}/{n_couplings} transport coupling matrices")
else:
    print("⚠ Transport couplings not available")

## 10. Runtime Summary

In [None]:
# Timings gathered by the UMAP, GW and baseline cells above, printed as one table.
# A single print with sep="\n" writes each entry on its own line.
rule = "=" * 70
print(
    rule,
    "RUNTIME SUMMARY",
    rule,
    f"UMAP fitting (384D → 14D):     {t_umap_fit:>8.2f}s",
    f"Global SeqOT (GW on 14D):      {t_gw:>8.1f}s",
    f"Procrustes (on 384D):          {t_procrustes:>8.2f}s",
    f"{'─' * 70}",
    f"Total GW pipeline:             {t_umap_fit + t_gw:>8.1f}s",
    rule,
    sep="\n",
)

## Key Takeaways

1. **UMAP** effectively reduces dimensionality (384D → 14D) while preserving non-linear manifold structure
2. **GW** runs efficiently on the 14D UMAP-reduced space (~7–8 minutes for 100 papers/year)
3. **Procrustes** remains competitive in UMAP space, possibly because UMAP already captures important geometric structure
4. **Future direction**: Coupled Flow Matching for learnable non-linear reduction with reconstruction