# Foundation Models for Earth Observation

This notebook demonstrates using embeddings from pre-trained foundation models for crop type classification.

## Learning Objectives
- Load pre-computed embeddings from the AlphaEarth Foundations model
- Visualize high-dimensional embeddings
- Apply unsupervised clustering
- Train supervised classifiers for crop type classification
- Test generalization across regions

## What are Foundation Model Embeddings?
Foundation models learn representations from large amounts of satellite imagery. These representations are called "embeddings" - compact numerical features that capture patterns in the imagery. We use these embeddings as features for downstream tasks like crop classification.

We'll be using [Google AlphaEarth Embeddings](https://arxiv.org/pdf/2507.22291), but various other embedding models exist. Not all models are created equal, so it's important to understand how a model was pretrained and what limitations it has.

## Setup

Install required packages and import helper functions.

In [None]:
# Install dependencies
%pip install \
  numpy==1.26.4 \
  scipy==1.11.4 \
  scikit-learn==1.3.2 \
  pandas==2.1.4 \
  matplotlib==3.8.4 \
  rasterio==1.3.9 \
  xarray==2024.7.0 \
  pyproj==3.6.1 \
  dask==2024.7.1 \
  rioxarray==0.15.7 \
  pystac-client==0.7.6 \
  planetary-computer==1.0.0 \
  pyarrow==14.0.2 \
  tqdm==4.66.4 \
  leafmap==0.34.3 \
  seaborn \
  -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import pandas as pd

# Import our helper functions
from geo_helpers import (
    load_sentinel2_rgb,
    load_sentinel2_rgb_timeseries,
    load_crop_labels,
    load_foundation_model_embeddings,
    create_embedding_rgb,
    simplify_crop_labels,
    align_labels_to_embeddings,
    prepare_training_data,
    get_class_names,
    print_crop_statistics,
    prepare_csv_samples,
    show_split_statistics,
    compute_clusters,
    predict_on_embeddings,
    prepare_sample_coordinates
)

from viz_helpers import (
    plot_rgb_image,
    plot_rgb_timeseries,
    plot_crop_labels,
    plot_embeddings_rgb,
    plot_clustering_results,
    plot_classification_results,
    plot_prediction_map,
    plot_generalization_comparison,
    show_study_area_map,
    plot_classification_vs_clustering,
    plot_sample_map_by_class,
    plot_sample_map_by_split
)

## 1. Define Study Area

We'll start with a region in the Mississippi Delta, USA - a major corn and soybean production area.

In [None]:
# Mississippi Delta region - corn/soy agricultural area
LON_MIN, LON_MAX = -90.90, -90.75
LAT_MIN, LAT_MAX = 33.45, 33.60
YEAR = 2024

# Visualize the study area on an interactive map
print("Study Area")
show_study_area_map(LON_MIN, LON_MAX, LAT_MIN, LAT_MAX)

## 2. Load Data

This notebook works with three data sources:
- Sentinel-2 RGB imagery
- Crop labels from USDA Cropland Data Layer
- Foundation model embeddings (64-dimensional features)

In [None]:
# Load Sentinel-2 RGB imagery for growing season
# Mississippi growing season: April-October for corn/soy
GROWING_SEASON_MONTHS = [4, 5, 6, 7, 8, 9, 10]  # Apr-Oct

rgb_timeseries = load_sentinel2_rgb_timeseries(
    LON_MIN, LON_MAX, LAT_MIN, LAT_MAX, YEAR, GROWING_SEASON_MONTHS
)

plot_rgb_timeseries(rgb_timeseries, "Mississippi Delta", YEAR)

In [None]:
# Load foundation model embeddings (64-dimensional feature vector for each pixel)
embeddings = load_foundation_model_embeddings(LON_MIN, LON_MAX, LAT_MIN, LAT_MAX, YEAR)

print(f"Embedding shape: {embeddings.shape}")
print(f"  - {embeddings.shape[0]} feature dimensions capturing crop characteristics")
print(f"  - {embeddings.shape[1]} x {embeddings.shape[2]} pixels")


## 3. Visualize Embeddings

Embeddings are 64-dimensional, so we can't visualize them directly. We'll create a false color RGB by mapping 3 dimensions to the red, green, and blue channels.

Try different dimensions to see what features they capture.
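
As a rough illustration of what such a false-color mapping involves, here is a minimal sketch. It is not the implementation of `create_embedding_rgb`; the function name `embedding_to_rgb`, the percentile stretch, and the `(n_dims, height, width)` layout are assumptions made for this example.

```python
import numpy as np

def embedding_to_rgb(emb, dims=(0, 10, 2), low=2, high=98):
    """Map three embedding dimensions to an RGB image scaled to [0, 1].

    Assumes emb has shape (n_dims, height, width).
    """
    # Stack the three chosen dimensions as the last (channel) axis
    rgb = np.stack([emb[d] for d in dims], axis=-1).astype(np.float64)
    for c in range(3):
        # Percentile stretch so a few extreme values do not wash out the image
        lo, hi = np.nanpercentile(rgb[..., c], [low, high])
        scale = hi - lo if hi > lo else 1.0  # guard against a constant channel
        rgb[..., c] = np.clip((rgb[..., c] - lo) / scale, 0.0, 1.0)
    return rgb

# Synthetic stand-in for a 64-dimensional embedding raster
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(64, 32, 48))
rgb = embedding_to_rgb(fake_embeddings)
print(rgb.shape)
```

Each channel is stretched independently, so colors show relative differences within a dimension rather than absolute embedding values.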

In [None]:
# Create RGB visualization from 3 embedding dimensions
dimensions_to_visualize = [0, 10, 2]
embedding_rgb = create_embedding_rgb(embeddings, bands=dimensions_to_visualize)
plot_embeddings_rgb(embedding_rgb, bands=dimensions_to_visualize)

## 4. Unsupervised Clustering

Embeddings are information-rich. Clustering them often reveals patterns that map to real-world features - fields with the same crop, built-up areas vs vegetation, etc.

Experiment with the number of clusters to see different levels of granularity.

Note: more clusters mean longer processing time.
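
For readers curious about what clustering an embedding raster involves, the sketch below shows one possible approach with scikit-learn's `KMeans`. It is not the code behind `compute_clusters`; the function name `cluster_embedding_raster` and the `(n_dims, height, width)` layout are assumptions, and NaN pixels are not handled.

```python
import numpy as np
from sklearn.cluster import KMeans

def cluster_embedding_raster(emb, k, seed=42):
    """Cluster every pixel of an (n_dims, height, width) raster into k groups."""
    n_dims, height, width = emb.shape
    # One row per pixel, one column per embedding dimension
    pixels = emb.reshape(n_dims, -1).T
    labels = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(pixels)
    # Put the cluster ids back on the image grid
    return labels.reshape(height, width)

# Synthetic raster: the right half is offset so it forms a separate cluster
rng = np.random.default_rng(0)
toy_raster = rng.normal(scale=0.1, size=(8, 20, 30))
toy_raster[:, :, 15:] += 3.0
toy_clusters = cluster_embedding_raster(toy_raster, k=2)
print(toy_clusters.shape)
```

Cluster ids are arbitrary: the same field may be labeled 0 in one run and 1 in another, so only the spatial pattern is meaningful.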

In [None]:
# Try k-means clustering with different numbers of clusters
number_of_clusters_to_explore = [3, 5, 10]  # Try changing this

# Compute clusters
cluster_results = compute_clusters(embeddings, k_values=number_of_clusters_to_explore)

# Plot results
plot_clustering_results(cluster_results=cluster_results)

## 5. Training a Corn vs Soy vs Other Classifier (Supervised)

Pre-computed embeddings are convenient - instead of downloading full satellite imagery tiles, we just fetch embeddings for specific points.

We've extracted AlphaEarth embeddings using Google Earth Engine for points across Mississippi, with labels from the USDA Cropland Data Layer (CDL). Note that CDL is itself a model prediction, not true ground truth. The CSV contains 1,000 samples per class.

Let's load the data and train a classifier.

In [None]:
# Load pre-extracted Mississippi embeddings

samples_mississipi = pd.read_csv("demo_data/mississippi_alphaearth_2024.csv")
print(f"Loaded {len(samples_mississipi)} samples with embeddings")
print(f"\nClasses: {samples_mississipi['class_name'].value_counts().to_dict()}")

In [None]:
# Prepare training and test splits
NUM_SAMPLES = {
    'corn': {'label': 0, 'n_train': 20, 'n_test': 100},
    'soy': {'label': 1, 'n_train': 20, 'n_test': 100},
    'other': {'label': 2, 'n_train': 20, 'n_test': 50}
}

X_train, y_train, X_test, y_test, train_idx, test_idx = \
    prepare_csv_samples(samples_mississipi, NUM_SAMPLES, random_seed=42)

show_split_statistics(samples_mississipi, train_idx, test_idx)

### Spatial Distribution of Samples

Looking at where the samples are located helps with:
- Identifying spatial autocorrelation (nearby samples are often similar)
- Detecting spatial bias
- Understanding how well training and test data cover the region

We'll create interactive maps showing samples and their train/test split.

**Note:** Some samples appear near water because CDL classifies all land pixels across Mississippi, including wetlands, floodplains, and transitional areas. The state has major water bodies (Mississippi River, Gulf Coast).


In [None]:
# Prepare sample coordinates for visualization
samples_mississipi = prepare_sample_coordinates(samples_mississipi, train_idx, test_idx)

In [None]:
# Visualize all samples by crop class
m_all = plot_sample_map_by_class(samples_mississipi, center_lat=33.5, center_lon=-90.8, zoom=7)
m_all

In [None]:
# Visualize train vs test split
m_split = plot_sample_map_by_split(samples_mississipi, center_lat=33.5, center_lon=-90.8, zoom=7)
m_split

### Understanding Spatial Autocorrelation

Spatial autocorrelation means that observations at nearby locations tend to be similar.

Tobler's First Law of Geography: "Everything is related to everything else, but near things are more related than distant things."

**Why this matters:**
- Train/test samples that are close together can inflate accuracy metrics
- Models may learn location-specific patterns rather than generalizable features
- Violates the i.i.d. assumption of standard ML

**In crop classification:**
- Nearby fields often have similar crops
- Environmental conditions vary smoothly across space
- Management practices are regional
- Satellite imagery captures spatial patterns

Always test on geographically separated regions.
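
One common way to reduce leakage from spatial autocorrelation is to split by spatial blocks rather than by individual samples. The sketch below is an illustration under assumptions, not part of the helper modules: `spatial_block_ids`, the 0.5 degree block size, and the synthetic coordinates are all made up for the example. It uses scikit-learn's `GroupShuffleSplit` so that no block contributes to both sides of the split.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

def spatial_block_ids(lon, lat, block_deg=0.5):
    """Assign each point to a square lon/lat block of block_deg degrees."""
    col = np.floor(np.asarray(lon) / block_deg).astype(int)
    row = np.floor(np.asarray(lat) / block_deg).astype(int)
    # Combine row and column into a single id per block
    return np.array([f"{r}_{c}" for r, c in zip(row, col)])

# Synthetic sample locations roughly covering Mississippi
rng = np.random.default_rng(0)
lon = rng.uniform(-91.5, -88.5, size=300)
lat = rng.uniform(30.5, 34.5, size=300)
groups = spatial_block_ids(lon, lat)

# Whole blocks go to either train or test, never both
splitter = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
block_train_idx, block_test_idx = next(
    splitter.split(np.zeros((len(lon), 1)), groups=groups)
)
shared_blocks = set(groups[block_train_idx]) & set(groups[block_test_idx])
print(f"Blocks shared between train and test: {len(shared_blocks)}")
```

Samples on either side of a block boundary can still be close to each other, so larger blocks or a buffer zone may be needed depending on how far the autocorrelation extends.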


In [None]:
# Train classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

print("Training Random Forest classifier...")
clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy:.2%}")

# Show results
plot_classification_results(y_test, y_pred, accuracy, ['other', 'corn', 'soy'])
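
`NUM_SAMPLES` above assigns corn to 0, soy to 1 and other to 2, while the plotting call lists the names as `['other', 'corn', 'soy']`. Which order matches the CSV is not visible from this notebook, so it is worth checking before reading the confusion matrix. A minimal sketch for deriving the names from the data instead of typing them by hand is shown here. It assumes the DataFrame has `label` and `class_name` columns, as the Minnesota cell later in this notebook uses; `class_names_by_label` is a hypothetical helper.

```python
import pandas as pd

def class_names_by_label(samples):
    """Return class names ordered by their integer label."""
    pairs = (
        samples[["label", "class_name"]]
        .drop_duplicates()
        .sort_values("label")
    )
    return pairs["class_name"].tolist()

# Toy frame that mirrors the mapping written in NUM_SAMPLES
toy_samples = pd.DataFrame({
    "label": [2, 0, 1, 0, 2],
    "class_name": ["other", "corn", "soy", "corn", "other"],
})
toy_class_names = class_names_by_label(toy_samples)
print(toy_class_names)  # ['corn', 'soy', 'other']
```

Passing a list built this way to the plotting functions keeps axis labels aligned with the integer labels the classifier predicts.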

## 6. Apply Classifier to Our Study Area

Now let's apply our trained classifier to the embeddings we loaded for our study area and compare it with the clustering results!
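
Applying a per-sample classifier to a raster mostly comes down to reshaping. The sketch below illustrates the idea on synthetic data. It is not the implementation of `predict_on_embeddings`, and `predict_raster` and the `(n_dims, height, width)` layout are assumptions.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def predict_raster(model, emb):
    """Predict a class for every pixel of an (n_dims, height, width) raster."""
    n_dims, height, width = emb.shape
    # Flatten to one row per pixel so the model sees the same layout as in training
    pixels = emb.reshape(n_dims, -1).T
    return model.predict(pixels).reshape(height, width)

# Toy training data: three classes separated by a constant offset
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(60, 64))
y_toy = np.repeat([0, 1, 2], 20)
X_toy[y_toy == 1] += 2.0
X_toy[y_toy == 2] -= 2.0

toy_model = RandomForestClassifier(n_estimators=20, random_state=0).fit(X_toy, y_toy)
toy_map = predict_raster(toy_model, rng.normal(size=(64, 10, 12)))
print(toy_map.shape)
```

For large rasters, predicting in chunks of rows keeps memory use bounded.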

In [None]:
# Apply classifier to spatial embeddings from our study area

print("Predicting crop types for study area...")
crop_predictions = predict_on_embeddings(clf, embeddings)

print("Predictions complete!")
print(f"Map shape: {crop_predictions.shape}")
print("\nPredicted crop distribution:")
for label, name in enumerate(['other', 'corn', 'soy']):
    count = np.sum(crop_predictions == label)
    pct = 100 * count / crop_predictions.size
    print(f"  {name.capitalize()}: {pct:.1f}%")

In [None]:
# Compare classification with clustering using precomputed clusters

print("Comparing supervised classification with unsupervised clustering...")

# Display side-by-side comparison
plot_classification_vs_clustering(
    embeddings, 
    None,
    crop_predictions, 
    {5: cluster_results[5]},  # Only use k=5 cluster from precomputed results
    class_names=['other', 'corn', 'soy']
)

**Understanding Limitations**

The unsupervised clustering looks cleaner - field borders and roads are clearly visible. Why doesn't the supervised model show the same?

This highlights the importance of understanding your training data, not just validation metrics.

Our labeled samples come from CDL, whose 30 m resolution is too coarse to resolve most roads, so road pixels are often merged into the surrounding fields. The model never learned to separate roads from crops, so it classifies road pixels as corn or soy.

## 7. Testing Model Generalization Across Regions

### Can Foundation Models Generalize?

We trained our classifier with samples from Mississippi and achieved great results! But in real-world applications (especially disaster response), we often need models to work in completely new regions.

**The Challenge:** Mississippi and Minnesota are ~1,400 km apart with very different environmental conditions:

| Factor | Mississippi | Minnesota |
|--------|-------------|-----------|
| **Latitude** | 30.2°N - 35.0°N | 43.5°N - 49.4°N |
| **Climate** | Humid Subtropical | Continental |
| **Growing Season** | 240-270 days | 120-160 days |
| **Summer Temperature** | 27-29°C | 20-23°C |
| **Winter Temperature** | 8-11°C | -12 to -7°C |
| **Annual Rainfall** | 1,400-1,600 mm | 500-750 mm |
| **Soil Type** | Clay loam, alluvial | Mollisols, glacial till |

**Why this matters for ML:**
- Different spectral signatures due to climate
- Different crop phenology (growth timing)
- Different stress patterns (heat vs cold)
- Testing true generalization capability

These differences create a "domain shift" - can our model handle it? Let's find out!
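
Before looking at accuracy, one way to check whether two sets of embeddings differ at all is a domain classifier (sometimes called adversarial validation). It trains a model to tell source samples from target samples; a cross-validated AUC near 0.5 suggests little detectable shift, and values near 1.0 suggest the regions are easy to tell apart. The sketch below uses synthetic data. `domain_shift_auc` is a hypothetical helper, and a linear model only picks up shifts that are roughly linear.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def domain_shift_auc(X_source, X_target):
    """Cross-validated AUC of a classifier separating source from target rows."""
    X = np.vstack([X_source, X_target])
    # 0 = source region, 1 = target region
    domain = np.concatenate([np.zeros(len(X_source)), np.ones(len(X_target))])
    model = LogisticRegression(max_iter=1000)
    return cross_val_score(model, X, domain, cv=5, scoring="roc_auc").mean()

# Synthetic stand-ins for embeddings from two regions
rng = np.random.default_rng(0)
source = rng.normal(size=(200, 8))
target_same = rng.normal(size=(200, 8))              # same distribution
target_shifted = rng.normal(loc=1.0, size=(200, 8))  # mean shifted by 1

auc_same = domain_shift_auc(source, target_same)
auc_shifted = domain_shift_auc(source, target_shifted)
print(f"AUC without shift: {auc_same:.2f}")
print(f"AUC with shift:    {auc_shifted:.2f}")
```

A high AUC shows that the feature distributions differ, not that the crop classifier will fail, so it complements the accuracy test rather than replacing it.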

### Load Minnesota Data for Testing

In [None]:
# Load Minnesota data
samples_minnesota = pd.read_csv("demo_data/minnesota_alphaearth_2024.csv")

# Prepare test set (300 samples per class)
embedding_columns = [col for col in samples_minnesota.columns if col.startswith('A')]
X_minnesota_all = samples_minnesota[embedding_columns].values
y_minnesota_all = samples_minnesota['label'].values

# Sample test set
np.random.seed(42)
test_indices_mn = []
for label in [0, 1, 2]:
    class_idx = np.where(y_minnesota_all == label)[0]
    sampled = np.random.choice(class_idx, min(300, len(class_idx)), replace=False)
    test_indices_mn.extend(sampled)

test_indices_mn = np.array(test_indices_mn)
np.random.shuffle(test_indices_mn)

X_test_minnesota = X_minnesota_all[test_indices_mn]
y_test_minnesota = y_minnesota_all[test_indices_mn]

print(f"Minnesota test set: {len(X_test_minnesota)} samples")
for label in [0, 1, 2]:
    count = np.sum(y_test_minnesota == label)
    name = samples_minnesota[samples_minnesota['label'] == label]['class_name'].iloc[0]
    print(f"  {name}: {count} samples")

In [None]:
# Test Mississippi-trained model on Minnesota (Zero-Shot Transfer)
print("Testing zero-shot transfer: Mississippi model → Minnesota data")
print("="*60)

y_pred_zeroshot = clf.predict(X_test_minnesota)
accuracy_zeroshot = accuracy_score(y_test_minnesota, y_pred_zeroshot)

print(f"\nZero-Shot Accuracy: {accuracy_zeroshot:.2%}")
print(f"\n  Trained on: {len(X_train)} Mississippi samples")
print(f"  Tested on: {len(X_test_minnesota)} Minnesota samples")
print(f"  Distance: ~1,400 km apart")

# Show results
plot_classification_results(y_test_minnesota, y_pred_zeroshot, accuracy_zeroshot, 
                           ['other', 'corn', 'soy'])

### Few-Shot Adaptation

Performance dropped significantly due to distribution shift between Mississippi and Minnesota.

Can we improve performance by adding a few samples from the target region?

Let's add **5 Minnesota samples per class** (15 total) to our training set.

In [None]:
# Add 5 Minnesota samples per class to training set
N_FEWSHOT = 5

# Get samples not in test set
remaining_idx = np.setdiff1d(np.arange(len(y_minnesota_all)), test_indices_mn)

fewshot_indices = []
for label in [0, 1, 2]:
    class_remaining = remaining_idx[y_minnesota_all[remaining_idx] == label]
    sampled = np.random.choice(class_remaining, N_FEWSHOT, replace=False)
    fewshot_indices.extend(sampled)

fewshot_indices = np.array(fewshot_indices)

X_fewshot = X_minnesota_all[fewshot_indices]
y_fewshot = y_minnesota_all[fewshot_indices]

# Combine with Mississippi training data
X_train_combined = np.vstack([X_train, X_fewshot])
y_train_combined = np.concatenate([y_train, y_fewshot])

print(f"Combined training set:")
print(f"  Mississippi: {len(X_train)} samples")
print(f"  Minnesota: {len(X_fewshot)} samples")
print(f"  Total: {len(X_train_combined)} samples")

In [None]:
# Train combined model
print("Training model with Mississippi + Minnesota few-shot samples...")

clf_combined = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
clf_combined.fit(X_train_combined, y_train_combined)

# Test on Minnesota
y_pred_fewshot = clf_combined.predict(X_test_minnesota)
accuracy_fewshot = accuracy_score(y_test_minnesota, y_pred_fewshot)

print(f"\nFew-Shot Accuracy: {accuracy_fewshot:.2%}")
print(f"  Improvement: {accuracy_fewshot - accuracy_zeroshot:+.2%}")

# Show results
plot_classification_results(y_test_minnesota, y_pred_fewshot, accuracy_fewshot,
                           ['other', 'corn', 'soy'])

In [None]:
# Visualize comparison
from sklearn.metrics import confusion_matrix
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

class_names = ['other', 'corn', 'soy']

# Zero-shot confusion matrix
cm_zero = confusion_matrix(y_test_minnesota, y_pred_zeroshot)
sns.heatmap(cm_zero, annot=True, fmt='d', cmap='Reds', ax=axes[0],
            xticklabels=class_names, yticklabels=class_names)
axes[0].set_title(f'Zero-Shot Transfer\n(Mississippi only)\nAccuracy: {accuracy_zeroshot:.2%}',
                  fontsize=13, fontweight='bold')
axes[0].set_ylabel('True Label')
axes[0].set_xlabel('Predicted Label')

# Few-shot confusion matrix
cm_few = confusion_matrix(y_test_minnesota, y_pred_fewshot)
sns.heatmap(cm_few, annot=True, fmt='d', cmap='Greens', ax=axes[1],
            xticklabels=class_names, yticklabels=class_names)
axes[1].set_title(f'Few-Shot Adapted\n(+5 MN samples/class)\nAccuracy: {accuracy_fewshot:.2%}',
                  fontsize=13, fontweight='bold')
axes[1].set_ylabel('True Label')
axes[1].set_xlabel('Predicted Label')

plt.tight_layout()
plt.show()

## Key Takeaways

### What We Learned About Foundation Models

1. **Rich Embeddings Capture Crop Identity**
   - 64-dimensional embeddings from foundation models encode meaningful crop characteristics
   - Unsupervised clustering reveals natural patterns that often align with crop types
   - Can be visualized as false-color images to understand what the model "sees"

2. **Efficient Training with Few Samples**
   - We can train a supervised classifier with only a few samples per class
   - Traditional ML would typically need hundreds or thousands of samples
   - Foundation models transfer knowledge from pre-training on massive datasets

3. **Zero-Shot Transfer Works Across Regions**
   - Model trained on Mississippi still achieves decent accuracy on Minnesota, although lower than within Mississippi
   - ~1,400 km apart with very different climate and growing conditions
   - Shows foundation models learn generalizable crop features

4. **Few-Shot Adaptation is Powerful**
   - Adding just 5 samples per class (15 total) significantly improves performance
   - Orders of magnitude less data than traditional approaches
   - Critical for rapid deployment in disaster response scenarios

### Practical Applications

- **Disaster Response**: Quickly adapt models to new regions with minimal data collection
- **Global Monitoring**: Train in data-rich areas, deploy to data-poor regions
- **Cost Efficiency**: Reduce field campaign costs dramatically
- **Rapid Prototyping**: Test approaches with small samples before scaling

### Important Considerations

- **Spatial Autocorrelation**: Always test on geographically separated regions
- **Domain Shift**: Understand environmental differences between train/test regions
- **Validation Strategy**: This example uses only a train and a test set for demonstration purposes. In real-world applications, use proper train/validation/test splits or k-fold cross-validation when tuning hyperparameters.
- **Ground Truth Quality**: Be aware that CDL is predicted data, not perfect ground truth
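
As a sketch of the k-fold option mentioned under Validation Strategy, the example below runs stratified 5-fold cross-validation with scikit-learn on synthetic 64-dimensional features. The data and variable names are made up for illustration. Plain stratified folds ignore sample locations, so with real spatial data a grouped splitter such as `GroupKFold` over spatial blocks would be the more cautious choice.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic features: 40 samples per class, with a class-dependent offset
rng = np.random.default_rng(0)
y_demo = np.repeat([0, 1, 2], 40)
X_demo = rng.normal(size=(120, 64)) + y_demo[:, None]

# Each fold keeps the class proportions of the full dataset
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42)
fold_scores = cross_val_score(model, X_demo, y_demo, cv=cv, scoring="accuracy")

print(f"Accuracy: {fold_scores.mean():.2%} +/- {fold_scores.std():.2%}")
```

Reporting the spread across folds alongside the mean gives a sense of how much a single train/test split could mislead.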

### Try Experimenting!

- Change the number of training samples (5, 10, 50 per class)
- Try different numbers of clusters for unsupervised learning
- Visualize different embedding dimensions
- Test on other regions (you will need a Google Earth Engine account to extract the samples)
- Compare different classifiers (SVM, XGBoost, Neural Networks)
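
For the first experiment, a small loop over training-set sizes is often enough. The sketch below runs on synthetic data; `accuracy_by_train_size` is a hypothetical helper and does not use `prepare_csv_samples`. With the real CSVs, the pool and evaluation arrays would come from the notebook's own splits.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

def accuracy_by_train_size(X_pool, y_pool, X_eval, y_eval, sizes, seed=42):
    """Train on n samples per class for each n in sizes and report accuracy."""
    rng = np.random.default_rng(seed)
    results = {}
    for n in sizes:
        # Draw n training samples from every class without replacement
        idx = np.concatenate([
            rng.choice(np.where(y_pool == label)[0], n, replace=False)
            for label in np.unique(y_pool)
        ])
        model = RandomForestClassifier(n_estimators=100, random_state=seed)
        model.fit(X_pool[idx], y_pool[idx])
        results[n] = accuracy_score(y_eval, model.predict(X_eval))
    return results

# Synthetic three-class data, split into a sampling pool and an evaluation set
rng = np.random.default_rng(0)
y_all = np.repeat([0, 1, 2], 150)
X_all = rng.normal(size=(450, 64)) + 0.5 * y_all[:, None]
size_results = accuracy_by_train_size(
    X_all[::2], y_all[::2], X_all[1::2], y_all[1::2], sizes=(5, 10, 50)
)
for n, acc in size_results.items():
    print(f"{n:>3} samples/class: {acc:.2%}")
```

Repeating the loop with several seeds shows how much of the difference between sizes is noise from the particular samples drawn.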