In [None]:
# Install required packages (uncomment if running for the first time)
# %pip install holoviews bokeh scikit-learn

# For this demo, install the current project in editable mode (this also installs the dependencies it declares)
%uv pip install -e .

# FOSS4G 2025 Demo: TerraMind Embedding Generation with odc-stac

This notebook demonstrates the complete workflow for generating geospatial embeddings from satellite imagery:

1. **Load satellite data** from STAC catalogs using odc-stac
2. **Process RGB composites** for foundation model input
3. **Load TerraMind model** (or fallback models) with TerraTorch
4. **Generate embeddings** from 16x16 pixel patches
5. **Visualize embeddings** in 3D space using dimensionality reduction

## Key Technologies

- **odc-stac**: Load STAC items into xarray Datasets
- **TerraTorch**: Foundation model integration and training toolkit
- **TerraMind**: IBM's geospatial foundation model (768-dim embeddings)
- **Element84 Earth Search**: AWS-hosted STAC catalog for satellite data
- **HoloViews**: Interactive 3D visualization of embedding space

## 1. Import Required Libraries

Import all necessary libraries for our TerraMind embedding generation workflow.

In [None]:
# Install required packages (uncomment if running for the first time)
# %pip install odc-stac terratorch pystac-client xarray rasterio matplotlib
# %pip install holoviews bokeh scikit-learn

import warnings
warnings.filterwarnings("ignore")

import json
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import odc.stac

# STAC and data loading
import pystac_client

# TerraTorch and ML
# PyTorch is a hard requirement: later cells call torch.device / torch.no_grad
# unconditionally, so a missing install should fail here rather than surface
# as a NameError several cells further down.
import torch

# TerraTorch is optional. When it cannot be imported BACKBONE_REGISTRY stays
# None and the model loader falls back to non-TerraTorch models.
BACKBONE_REGISTRY = None
try:
    # Try multiple import patterns for BACKBONE_REGISTRY
    try:
        from terratorch.models.backbones import BACKBONE_REGISTRY
        print("‚úÖ TerraTorch BACKBONE_REGISTRY imported successfully")
    except ImportError:
        from terratorch import BACKBONE_REGISTRY
        print("‚úÖ TerraTorch BACKBONE_REGISTRY imported from main module")
except ImportError as e:
    # Both import locations failed: report the real reason and continue.
    print(f"‚ö†Ô∏è TerraTorch import issue: {e}")
    print("‚ö†Ô∏è BACKBONE_REGISTRY not found, will use fallback models")
    BACKBONE_REGISTRY = None
else:
    # Only claim success when the registry was actually imported.
    print("‚úÖ TerraTorch imported successfully")

# Visualization libraries
# HoloViews (with the bokeh backend) is optional; HV_AVAILABLE records whether
# it could be set up so the plotting cell can pick a backend.
try:
    import holoviews as hv
    hv.extension("bokeh")
except ImportError as e:
    HV_AVAILABLE = False
    print(f"‚ö†Ô∏è HoloViews not available: {e}")
    print("üìä Will use matplotlib for visualization instead")
else:
    HV_AVAILABLE = True
    print("‚úÖ HoloViews imported successfully")

# ML utilities
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity

# Configure logging: INFO-level messages formatted as "time - level - message"
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger shared by the rest of the notebook
logger = logging.getLogger(__name__)

print("üöÄ All available libraries imported successfully!")

## 2. Connect to STAC Catalog

Connect to Element84 Earth Search STAC catalog for satellite data discovery.

In [None]:
# --- Data source -----------------------------------------------------------
STAC_URL = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# --- Query parameters (demo area: Auckland, New Zealand) -------------------
BBOX = [174.6, -36.95, 174.85, -36.75]
DATETIME = "2023-12-01/2023-12-31"
BANDS = ["red", "green", "blue", "nir"]

# --- Open the catalog ------------------------------------------------------
logger.info(f"Connecting to STAC catalog: {STAC_URL}")
catalog = pystac_client.Client.open(STAC_URL)
print(f"‚úÖ Connected to {catalog.title}")

# --- Summarise what we are about to query ----------------------------------
print(f"üìç Catalog URL: {STAC_URL}")
collection_count = len(list(catalog.get_collections()))
print(f"üóÇÔ∏è Available collections: {collection_count}")
print(f"üéØ Target collection: {COLLECTION}")
print(f"üì¶ Area of Interest: {BBOX} (Auckland, NZ)")

## 3. Search and Load Satellite Data

Search for Sentinel-2 imagery and load it using odc-stac.

In [None]:
# Search for Sentinel-2 data
logger.info(f"Searching for {COLLECTION} data...")
search = catalog.search(
    collections=[COLLECTION],
    datetime=DATETIME,
    bbox=BBOX,
    limit=10,  # page size per API request (pystac-client caps totals with max_items)
    query={"eo:cloud_cover": {"lt": 50}},  # Increased cloud cover threshold
)

# Get search results
items = list(search.items())
print(f"üîç Found {len(items)} items with <50% cloud cover")

# If no items found, try with relaxed constraints
if len(items) == 0:
    print("‚ö†Ô∏è No items found, trying with relaxed constraints...")
    search = catalog.search(
        collections=[COLLECTION],
        datetime="2023-06-01/2023-08-31",  # Try a different period
        bbox=BBOX,
        limit=10,
        query={"eo:cloud_cover": {"lt": 80}},
    )
    items = list(search.items())
    print(f"üîç Found {len(items)} items with relaxed criteria")

if len(items) == 0:
    raise ValueError("No suitable Sentinel-2 data found for the specified region and time period")

# Load data using odc-stac
logger.info("Loading data with odc-stac...")
dataset = odc.stac.load(
    items,
    bands=BANDS,
    bbox=BBOX,  # restrict the load to the area of interest instead of whole scene footprints
    resolution=100,  # 100m resolution for demo
    chunks={"time": 1, "x": 512, "y": 512},
    groupby="solar_day",
)

# Dataset.sizes is the name -> length mapping (Dataset.dims is being changed to
# return only the dimension names in newer xarray releases).
print(f"‚úÖ Loaded dataset with shape: {dict(dataset.sizes)}")
print(f"üìä Data variables: {list(dataset.data_vars)}")
print(f"‚è∞ Time range: {dataset.time.values[0]} to {dataset.time.values[-1]}")

# Display dataset info: a bare last expression triggers the rich notebook repr
# (assigning it to `_` would suppress the output).
dataset

## 4. Create RGB Composite

Create RGB composite for visualization and model input.

In [None]:
def create_rgb_composite(dataset, time_index=-1):
    """Build an (..., 3) RGB array scaled to the 0-1 range.

    Args:
        dataset: Object exposing ``dims`` plus ``red``, ``green`` and ``blue``
            bands; when it has a ``time`` dimension a single step is selected
            with ``isel``.
        time_index: Index along ``time`` to use (default: last step).

    Returns:
        Array with the three bands stacked on the last axis, divided by
        10000 and clipped to [0, 1].
    """
    if "time" in dataset.dims:
        frame = dataset.isel(time=time_index)
    else:
        frame = dataset

    # Bands go on the trailing axis in R, G, B order
    bands = [frame.red, frame.green, frame.blue]

    # Sentinel-2 values are scaled by 10000; clip anything outside 0-1
    reflectance = np.stack(bands, axis=-1) / 10000.0
    return np.clip(reflectance, 0, 1)


# Build the composite from the last time step of the loaded stack
logger.info("Creating RGB composite...")
rgb_composite = create_rgb_composite(dataset, time_index=-1)

print(f"üì∏ RGB composite shape: {rgb_composite.shape}")
value_low, value_high = np.nanmin(rgb_composite), np.nanmax(rgb_composite)
print(f"üìà Value range: [{value_low:.3f}, {value_high:.3f}]")

# Show the composite, titled with the timestamp of the last time step
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(rgb_composite)
ax.set_title(f"RGB Composite - Auckland, New Zealand\n{dataset.time.values[-1]}")
ax.axis("off")
fig.tight_layout()
plt.show()

# Keep a reference under the name the embedding cells use
rgb_array = rgb_composite

## 5. Load TerraMind Model

Load TerraMind foundation model with robust fallback system.

In [None]:
def load_terramind_model():
    """Load the first model of the fallback chain that can be built.

    Candidates are tried in order: TerraMind, Clay and Prithvi through the
    TerraTorch ``BACKBONE_REGISTRY``, then a timm ResNet18. The first model
    that builds is moved to the selected device and put in eval mode.

    Returns:
        The loaded model, on CUDA when available and otherwise on CPU.

    Raises:
        RuntimeError: If every candidate fails; the message lists the failure
            reason for each model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Model fallback chain
    models_to_try = [
        ("terramind_v1_base", "TerraMind foundation model"),
        ("clay_v1", "Clay foundation model"),
        ("prithvi_vit", "Prithvi Vision Transformer"),
        ("resnet18", "ResNet18 (timm)"),
    ]

    # Collected so the final error explains why each candidate was rejected
    failures = []

    for model_name, description in models_to_try:
        try:
            logger.info(f"Attempting to load model: {model_name}")

            if model_name == "resnet18":
                # Special handling for timm models; imported lazily because it
                # is only needed when every TerraTorch backbone failed.
                import timm
                model = timm.create_model("resnet18", pretrained=True, num_classes=0)
            else:
                # TerraTorch models
                if BACKBONE_REGISTRY is None:
                    raise ImportError("BACKBONE_REGISTRY not available")

                # Only pass `modalities` to the backbones it is meant for,
                # instead of sending modalities=None to the others.
                build_kwargs = {"pretrained": True}
                if "terra" in model_name or "clay" in model_name:
                    build_kwargs["modalities"] = ["S2RGB"]
                model = BACKBONE_REGISTRY.build(model_name, **build_kwargs)

            model = model.to(device)
            model.eval()

            logger.info(f"‚úÖ Successfully loaded {model_name}")
            print(f"ü§ñ Model: {description}")
            print(f"üì± Device: {device}")

            if model_name != "terramind_v1_base":
                print("‚ö†Ô∏è TerraMind not available, using fallback model")
                print("‚ö†Ô∏è Embeddings will be generated but may not be TerraMind-specific")

            return model

        except Exception as e:
            # Best effort: remember the reason and move on to the next candidate
            logger.warning(f"Failed to load {model_name}: {e}")
            failures.append(f"{model_name}: {e}")
            continue

    raise RuntimeError(
        "Could not load any model from the fallback chain (" + "; ".join(failures) + ")"
    )


# Load the model through the fallback chain
try:
    model = load_terramind_model()
except Exception as load_error:
    # Print install guidance, then let the original error propagate
    print(f"‚ùå Error loading model: {load_error}")
    print("‚ö†Ô∏è Please ensure terratorch is installed: pip install terratorch")
    raise

## 6. Prepare Data for TerraMind

Extract patches and normalize for model input.

In [None]:
def rgb_smooth_quantiles(rgb_array, quantiles=None):
    """Stretch each channel linearly between two of its quantiles.

    Args:
        rgb_array: Array of shape (height, width, channels).
        quantiles: ``[low, high]`` quantiles mapped to 0 and 1
            (default ``[0.02, 0.98]``). Values outside are clipped.

    Returns:
        Floating-point array of the same shape with values in [0, 1].
        NaN pixels stay NaN. A channel whose two quantiles coincide
        (e.g. a constant channel) maps its valid pixels to 0 instead of
        producing NaN/inf from a division by zero.
    """
    if quantiles is None:
        quantiles = [0.02, 0.98]

    rgb_array = np.asarray(rgb_array)
    # Keep a floating input dtype; integer input needs a float output,
    # otherwise the stretched values would be truncated to 0/1.
    if np.issubdtype(rgb_array.dtype, np.floating):
        out_dtype = rgb_array.dtype
    else:
        out_dtype = np.float64
    normalized = np.zeros(rgb_array.shape, dtype=out_dtype)

    for i in range(rgb_array.shape[2]):
        channel = rgb_array[:, :, i]
        valid_mask = ~np.isnan(channel)

        if not valid_mask.any():
            # Nothing to stretch: pass the (all-NaN) channel through
            normalized[:, :, i] = channel
            continue

        q_low, q_high = np.quantile(channel[valid_mask], quantiles)
        span = q_high - q_low
        if span != 0:
            normalized[:, :, i] = np.clip((channel - q_low) / span, 0, 1)
        else:
            # Degenerate stretch: avoid 0/0, keep NaN pixels as NaN
            normalized[:, :, i] = np.where(valid_mask, 0.0, np.nan)

    return normalized


def prepare_terramind_patches(rgb_data, patch_size=16):
    """Cut an image into non-overlapping square patches.

    Patches are taken row by row from the top-left corner; rows/columns that
    do not fill a complete patch at the bottom/right edge are dropped, and
    any patch containing a NaN is skipped.

    Args:
        rgb_data: Array of shape (height, width, channels).
        patch_size: Edge length of each patch in pixels (default 16).

    Returns:
        Array of shape (n_patches, patch_size, patch_size, channels). The
        shape stays 4-D even when no patch qualifies, so downstream code can
        rely on it.

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")

    rgb_data = np.asarray(rgb_data)
    height, width, channels = rgb_data.shape
    rows, cols = height // patch_size, width // patch_size

    # Drop the partial edge, then split both spatial axes into (tile, offset)
    cropped = rgb_data[: rows * patch_size, : cols * patch_size, :]
    tiles = (
        cropped.reshape(rows, patch_size, cols, patch_size, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(rows * cols, patch_size, patch_size, channels)
    )

    # Skip patches with NaN values
    keep = ~np.isnan(tiles).any(axis=(1, 2, 3))
    return tiles[keep]


def normalize_terramind_input(patches):
    """Convert NHWC patches to a normalized float32 NCHW tensor.

    Args:
        patches: NumPy array of shape (n, height, width, 3).

    Returns:
        ``torch.float32`` tensor of shape (n, 3, height, width) where each
        channel has the ImageNet mean subtracted and is divided by the
        ImageNet standard deviation.

    The computation is out-of-place: ``torch.from_numpy`` shares memory with
    the input, so writing into the tensor would silently modify ``patches``
    whenever it is already float32.
    """
    # ImageNet normalization, shaped to broadcast over (n, 3, h, w)
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)

    patches_tensor = torch.from_numpy(patches).float()
    patches_tensor = patches_tensor.permute(0, 3, 1, 2)  # NHWC -> NCHW

    return ((patches_tensor - mean) / std).contiguous()


# Process data
logger.info("Applying smooth quantile normalization...")
normalized_rgb = rgb_smooth_quantiles(rgb_array)

logger.info("Extracting 16x16 patches...")
patches = prepare_terramind_patches(normalized_rgb, patch_size=16)
print(f"üß© Extracted {len(patches)} patches")

# Stop here with a clear message: without patches the tensor conversion and
# the embedding step below would fail with much less obvious errors.
if len(patches) == 0:
    raise ValueError(
        "No NaN-free 16x16 patches could be extracted from the RGB composite"
    )

logger.info("Normalizing patches for model input...")
patches_tensor = normalize_terramind_input(patches)
print(f"üì¶ Normalized tensor shape: {patches_tensor.shape}")

## 7. Generate Embeddings

Generate embeddings from processed patches using the loaded model.

In [None]:
def _forward_backbone(model, batch):
    """Run ``model`` on ``batch``, trying progressively more generic call styles."""
    try:
        # Try TerraMind format first (dict keyed by modality name)
        return model({"S2RGB": batch})
    except (TypeError, KeyError, AttributeError, RuntimeError):
        pass  # model does not take a modality dict; use plain tensor input

    try:
        return model(batch)
    except Exception as direct_error:
        # Last resort: feature-extraction entry points
        if hasattr(model, "forward_features"):
            return model.forward_features(batch)
        if hasattr(model, "features"):
            return model.features(batch)
        raise RuntimeError("Cannot extract embeddings from this model") from direct_error


def _pool_to_vectors(output):
    """Reduce a backbone output to one feature vector per sample."""
    # Unwrap container outputs regardless of which call style produced them
    if isinstance(output, list):
        output = output[-1]
    elif isinstance(output, tuple):
        output = output[0]

    if not torch.is_tensor(output):
        raise TypeError(
            f"Unsupported model output type for embedding extraction: {type(output).__name__}"
        )

    if output.dim() == 3:
        # NOTE(review): assumes (batch, tokens, features), the usual layout of
        # transformer backbones - confirm for the model in use. Averaging over
        # tokens keeps the feature dimension intact.
        return output.mean(dim=1)
    if output.dim() > 3:
        # (batch, channels, *spatial): average over the spatial axes
        return output.mean(dim=tuple(range(2, output.dim())))
    return output


def generate_embeddings(patches_tensor, model, batch_size=32):
    """Generate one embedding vector per patch using the loaded model.

    Args:
        patches_tensor: Tensor of model inputs, batch dimension first.
        model: ``torch.nn.Module`` accepting either ``{"S2RGB": batch}`` or a
            plain tensor; may return a tensor or a list/tuple of tensors.
        batch_size: Number of patches per forward pass.

    Returns:
        NumPy array of shape (n_patches, embedding_dim).

    Raises:
        ValueError: If ``patches_tensor`` is empty.
        RuntimeError: If no supported way of calling the model works.
        TypeError: If the model output is not a tensor (or list/tuple of them).
    """
    if len(patches_tensor) == 0:
        raise ValueError("No patches provided; cannot generate embeddings")

    try:
        device = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: nothing pins it to a device
        device = torch.device("cpu")

    embeddings_list = []

    with torch.no_grad():
        for i in range(0, len(patches_tensor), batch_size):
            batch = patches_tensor[i : i + batch_size].to(device)

            batch_embeddings = _pool_to_vectors(_forward_backbone(model, batch))
            embeddings_list.append(batch_embeddings.cpu().numpy())

            # Report progress every 10 batches
            if (i // batch_size + 1) % 10 == 0:
                print(f"Processed {i + len(batch)}/{len(patches_tensor)} patches")

    return np.vstack(embeddings_list)


# Generate embeddings
logger.info("Generating embeddings...")
embeddings = generate_embeddings(patches_tensor, model, batch_size=16)

print(f"üéØ Generated embeddings shape: {embeddings.shape}")
print("üìä Embedding statistics:")
print(f"   Mean: {np.mean(embeddings):.4f}")
print(f"   Std:  {np.std(embeddings):.4f}")
print(f"   Min:  {np.min(embeddings):.4f}")
print(f"   Max:  {np.max(embeddings):.4f}")

# Calculate cosine similarity between first 10 embeddings
if len(embeddings) > 1:
    similarity_matrix = cosine_similarity(embeddings[:10])
    # Average over distinct pairs only (upper triangle without the diagonal);
    # the diagonal holds each embedding's similarity with itself and would
    # inflate the mean.
    distinct_pairs = np.triu_indices(len(similarity_matrix), k=1)
    avg_similarity = np.mean(similarity_matrix[distinct_pairs])
    print(f"   Avg cosine similarity (first 10): {avg_similarity:.4f}")

## 8. Dimensionality Reduction

Reduce embeddings to 3D for visualization using PCA and t-SNE.

In [None]:
# Subsample embeddings for visualization (if too many)
# A seeded generator keeps the subsample, and every plot derived from it,
# reproducible across runs.
vis_rng = np.random.default_rng(42)
n_vis = min(1000, len(embeddings))
if n_vis < len(embeddings):
    indices = vis_rng.choice(len(embeddings), n_vis, replace=False)
    embeddings_vis = embeddings[indices]
    print(f"üìâ Subsampled {n_vis} embeddings for visualization")
else:
    embeddings_vis = embeddings
    indices = np.arange(len(embeddings))

# Apply PCA for initial dimensionality reduction
print("üîÑ Applying PCA...")
# PCA cannot extract more components than min(n_samples, n_features), so cap
# the 50-component target for small inputs.
n_pca = min(50, *embeddings_vis.shape)
pca = PCA(n_components=n_pca)
embeddings_pca = pca.fit_transform(embeddings_vis)
print(
    f"üìä PCA explained variance ratio (first 5 components): {pca.explained_variance_ratio_[:5]}"
)
print(
    f"üìà Total variance explained by {n_pca} components: {pca.explained_variance_ratio_.sum():.3f}"
)

# Apply t-SNE for 3D visualization
print("üîÑ Applying t-SNE for 3D reduction...")
tsne = TSNE(
    n_components=3, random_state=42, perplexity=min(30, len(embeddings_vis) - 1)
)
embeddings_3d = tsne.fit_transform(embeddings_pca)

print(f"‚úÖ Reduced to 3D: {embeddings_3d.shape}")

# Also create PCA 3D for comparison (same component cap as above)
pca_3d = PCA(n_components=min(3, *embeddings_vis.shape))
embeddings_pca_3d = pca_3d.fit_transform(embeddings_vis)

print(f"üìä PCA 3D explained variance: {pca_3d.explained_variance_ratio_.sum():.3f}")

# Calculate colors based on embedding magnitudes, min-max scaled to [0, 1]
embedding_norms = np.linalg.norm(embeddings_vis, axis=1)
norm_span = embedding_norms.max() - embedding_norms.min()
if norm_span > 0:
    colors = (embedding_norms - embedding_norms.min()) / norm_span
else:
    # All norms identical: use a constant colour instead of dividing by zero
    colors = np.zeros_like(embedding_norms)

## 9. Interactive 3D Visualization with HoloViews

Create interactive scatter plots of the embedding space, showing the first two components of each 3D reduction.

In [None]:
# Prepare data for visualization
def create_scatter_data(coords_3d, colors, method_name, patch_ids=None):
    """Create data dictionary for scatter plot.

    Args:
        coords_3d: Array of shape (n_points, n_components).
        colors: One colour value per point.
        method_name: Label of the reduction method, repeated for every point.
        patch_ids: Identifier per point. When omitted, the notebook-level
            ``indices`` array is used.

    Returns:
        Dict with keys ``x``, ``y``, ``z``, ``color``, ``method`` and
        ``patch_id``. ``z`` repeats the first column when ``coords_3d`` has
        fewer than three columns.
    """
    if patch_ids is None:
        # Backward-compatible default: rely on the notebook global
        patch_ids = indices

    has_third_axis = coords_3d.shape[1] > 2
    return {
        "x": coords_3d[:, 0],
        "y": coords_3d[:, 1],
        "z": coords_3d[:, 2] if has_third_axis else coords_3d[:, 0],
        "color": colors,
        "method": [method_name] * len(coords_3d),
        "patch_id": patch_ids,
    }

# Build the plot inputs for both reductions
tsne_data = create_scatter_data(embeddings_3d, colors, "t-SNE")
pca_data = create_scatter_data(embeddings_pca_3d, colors, "PCA")

if not HV_AVAILABLE:
    # Static matplotlib fallback: one panel per reduction, first two components
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    panels = [
        (axes[0], embeddings_3d, "t-SNE Embedding Space", "Component 1", "Component 2"),
        (axes[1], embeddings_pca_3d, "PCA Embedding Space", "PC 1", "PC 2"),
    ]
    for ax, coords, panel_title, x_label, y_label in panels:
        points = ax.scatter(
            coords[:, 0], coords[:, 1],
            c=colors, cmap="viridis", alpha=0.7, s=10
        )
        ax.set_title(panel_title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        plt.colorbar(points, ax=ax)

    plt.tight_layout()
    plt.show()
    print("üìä Created 2D visualization with matplotlib")
else:
    # Interactive HoloViews scatter plots (2D; 3D scatter may not be available)
    opts_2d = dict(
        width=600,
        height=500,
        color="color",
        cmap="viridis",
        size=4,
        alpha=0.7,
        colorbar=True,
        tools=["hover"],
    )
    scatter_dims = dict(kdims=["x", "y"], vdims=["color", "patch_id"])

    tsne_plot = hv.Scatter(tsne_data, **scatter_dims).opts(
        title="t-SNE Embedding Space", **opts_2d
    )
    pca_plot = hv.Scatter(pca_data, **scatter_dims).opts(
        title="PCA Embedding Space", **opts_2d
    )

    print("üé® Created interactive scatter plots!")
    print("üí° Color represents embedding magnitude")
    print("üñ±Ô∏è Use mouse to zoom and explore")

    # Two columns: t-SNE on the left, PCA on the right
    layout = (tsne_plot + pca_plot).cols(2)
    display(layout)  # explicit display, since this is not the cell's last expression

## 10. Advanced Embedding Analysis

Analyze the structure and characteristics of the generated embeddings.

In [None]:
# Analyze embedding dimensions
dim_means = np.mean(embeddings, axis=0)
dim_stds = np.std(embeddings, axis=0)

# L2 norm of every embedding. Computed here from the full set: the
# `embedding_norms` array of the dimensionality-reduction cell only covers the
# visualization subsample, while every other statistic below uses `embeddings`.
all_embedding_norms = np.linalg.norm(embeddings, axis=1)

# Find most informative dimensions (ascending order, so [-1] is the largest)
most_variable_dims = np.argsort(dim_stds)[-10:]
highest_activation_dims = np.argsort(np.abs(dim_means))[-10:]

print("üìä Embedding Analysis:")
print(f"   Total dimensions: {embeddings.shape[1]}")
print(f"   Most variable dimensions: {most_variable_dims}")
print(f"   Highest activation dimensions: {highest_activation_dims}")

# Create distribution plots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Embedding magnitude distribution
axes[0, 0].hist(all_embedding_norms, bins=50, alpha=0.7, color="skyblue")
axes[0, 0].set_title("Distribution of Embedding Magnitudes")
axes[0, 0].set_xlabel("L2 Norm")
axes[0, 0].set_ylabel("Frequency")

# Dimension variance plot
axes[0, 1].plot(np.sort(dim_stds)[::-1], color="orange")
axes[0, 1].set_title("Dimension Standard Deviations (Sorted)")
axes[0, 1].set_xlabel("Dimension Rank")
axes[0, 1].set_ylabel("Standard Deviation")
axes[0, 1].set_yscale("log")

# Cosine similarity heatmap (subset); seeded so the figure is reproducible
analysis_rng = np.random.default_rng(42)
n_sample = min(50, len(embeddings))
sample_indices = analysis_rng.choice(len(embeddings), n_sample, replace=False)
similarity_subset = cosine_similarity(embeddings[sample_indices])

im = axes[1, 0].imshow(similarity_subset, cmap="coolwarm", vmin=0, vmax=1)
axes[1, 0].set_title(f"Cosine Similarity Matrix ({n_sample} samples)")
axes[1, 0].set_xlabel("Patch Index")
axes[1, 0].set_ylabel("Patch Index")
plt.colorbar(im, ax=axes[1, 0])

# Most variable dimensions
axes[1, 1].bar(
    range(len(most_variable_dims)),
    dim_stds[most_variable_dims],
    color="green",
    alpha=0.7,
)
axes[1, 1].set_title("10 Most Variable Dimensions")
axes[1, 1].set_xlabel("Dimension Index")
axes[1, 1].set_ylabel("Standard Deviation")
axes[1, 1].set_xticks(range(len(most_variable_dims)))
axes[1, 1].set_xticklabels(most_variable_dims, rotation=45)

plt.tight_layout()
plt.show()

# Mean over distinct pairs only: the diagonal is each sample's similarity with
# itself and would bias the "pairwise" mean upwards.
if n_sample > 1:
    mean_pairwise_similarity = float(
        np.mean(similarity_subset[np.triu_indices(n_sample, k=1)])
    )
else:
    mean_pairwise_similarity = float("nan")

# Print summary statistics
print("\nüéØ Summary Statistics:")
print(f"   Mean embedding magnitude: {np.mean(all_embedding_norms):.4f}")
print(f"   Std embedding magnitude: {np.std(all_embedding_norms):.4f}")
print(f"   Mean pairwise cosine similarity: {mean_pairwise_similarity:.4f}")
print(
    f"   Dimension with highest variance: {most_variable_dims[-1]} (œÉ={dim_stds[most_variable_dims[-1]]:.4f})"
)
print(
    f"   Dimension with highest activation: {highest_activation_dims[-1]} (Œº={dim_means[highest_activation_dims[-1]]:.4f})"
)

## 11. Save Results

Save embeddings and visualization data for future use.

In [None]:
# Create output directory (including missing parents)
output_dir = Path("../outputs")
output_dir.mkdir(parents=True, exist_ok=True)

# Save embeddings
embeddings_file = output_dir / "notebook_embeddings.npy"
np.save(embeddings_file, embeddings)

# Save 3D coordinates
np.save(output_dir / "embeddings_tsne_3d.npy", embeddings_3d)
np.save(output_dir / "embeddings_pca_3d.npy", embeddings_pca_3d)

# Save which embedding row each 3D point came from. The reductions may have
# been fitted on a subsample, so without these indices the saved coordinates
# cannot be matched back to rows of notebook_embeddings.npy.
np.save(output_dir / "embeddings_vis_indices.npy", indices)

# Save metadata
metadata = {
    "num_patches": len(embeddings),
    "num_visualized": int(len(indices)),
    "embedding_dim": embeddings.shape[1],
    "original_image_shape": rgb_array.shape,
    "patch_size": 16,
    "area": "Auckland, New Zealand",
    "bbox": BBOX,
    "datetime": DATETIME,
    "model_type": type(model).__name__,
    "statistics": {
        "mean": float(np.mean(embeddings)),
        "std": float(np.std(embeddings)),
        "min": float(np.min(embeddings)),
        "max": float(np.max(embeddings)),
    },
}

# Explicit encoding so the file does not depend on the platform default
with open(output_dir / "notebook_metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2)

print(f"üíæ Saved results to {output_dir}:")
print(f"   üìÅ embeddings: {embeddings_file}")
print("   üìÅ 3D coordinates: embeddings_tsne_3d.npy, embeddings_pca_3d.npy")
print("   visualization indices: embeddings_vis_indices.npy")
print("   üìÅ metadata: notebook_metadata.json")
print("\nüéâ TerraMind embedding generation completed successfully!")
print(f"üìä Generated {len(embeddings)} embeddings from {len(patches)} patches")
print("üé® Interactive 3D visualization shows embedding space structure")

## 🎉 Demo Complete!

This notebook demonstrated the complete workflow for generating geospatial embeddings from satellite imagery:

### What We Accomplished

1. **📡 Connected to Element84 Earth Search** - Accessed cloud-native STAC catalog
2. **🛰️ Loaded Sentinel-2 imagery** - Used odc-stac for efficient data loading  
3. **🖼️ Created RGB composites** - Processed satellite data for model input
4. **🤖 Loaded foundation models** - Used TerraTorch with robust fallback system
5. **✂️ Extracted image patches** - Prepared 16x16 pixel patches for embedding generation
6. **🧠 Generated embeddings** - Created high-dimensional feature representations
7. **📊 Applied dimensionality reduction** - Used PCA and t-SNE for visualization
8. **🎨 Created 3D visualizations** - Interactive exploration of embedding space

### Key Insights

- **Embedding Structure**: The 3D visualizations reveal the underlying structure in how the foundation model represents different image patches
- **Similarity Patterns**: Patches with similar visual characteristics cluster together in embedding space
- **Dimensionality**: Foundation models capture rich representations that can be effectively reduced for visualization
- **Geospatial Context**: The embeddings preserve spatial relationships and land cover patterns

### Next Steps

- **Classification**: Use embeddings for land cover classification tasks
- **Change Detection**: Compare embeddings across time periods
- **Similarity Search**: Find similar landscape patterns across different regions
- **Model Training**: Train lightweight downstream models on these embeddings, or fine-tune the foundation model itself for a specific task

### Resources

- [odc-stac Documentation](https://github.com/opendatacube/odc-stac)
- [TerraTorch GitHub](https://github.com/IBM/terratorch)
- [Element84 Earth Search](https://github.com/element84/earth-search)
- [HoloViews Documentation](https://holoviews.org/)