# AlphaEarth Data Preparation for Winter Season

This notebook prepares AlphaEarth satellite embeddings for integration with your existing CNN/MLP/GNN models.

## Overview

- Extract AlphaEarth embeddings from Google Earth Engine
- Prepare 4 integration options (A, B, C, D)
- Create patches for CNN input
- Prepare summary statistics for MLP input
- Save data for model training

In [None]:
import ee
import numpy as np
import pandas as pd
import rasterio
from rasterio.windows import Window
import os
import sys
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

print("All imports successful")


## Step 1: Initialize Google Earth Engine

In [None]:
# Initialize Earth Engine for this session. If initialization fails (for
# example because no credentials are stored yet), run the interactive
# authentication flow once and initialize again.
try:
    ee.Initialize(project='five-rivers-alphaearth')
    print("Earth Engine initialized successfully")
except Exception:
    # `except Exception` rather than a bare `except:` so that Ctrl-C
    # (KeyboardInterrupt) and SystemExit are not mistaken for an auth problem.
    ee.Authenticate()
    ee.Initialize(project='five-rivers-alphaearth')
    print("Earth Engine authenticated and initialized")


## Step 2: Load Sampling Points and Base Data

In [None]:
# Read the sampling locations and the rainy-season table from CSV.
# NOTE(review): the two files are addressed relative to different parent
# directories ('../data' vs '../../data') - confirm both paths are intended.
sampling_points = pd.read_csv('../data/Samples_100.csv')
rainy_data = pd.read_csv('../../data/RainySeason.csv')

# Report row counts and a preview of the first ten sampling-table columns.
print(
    f"Sampling points: {len(sampling_points)}",
    f"Rainy season data: {len(rainy_data)}",
    f"\nColumns in sampling data: {sampling_points.columns.tolist()[:10]}...",
    sep="\n",
)


## Step 3: Extract AlphaEarth Embeddings from Google Earth Engine

In [None]:
# Open the AlphaEarth annual satellite-embedding image collection.
embeddings_collection = ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')

# Study-area rectangle (labelled "Dhaka, Bangladesh" by the author); adjust
# the bounds to match your study area.
# NOTE(review): the four numbers look like [west, south, east, north] in
# degrees, which would end at 90.0 E - confirm the box actually covers the
# sampling points. `aoi` is only printed in the cells visible here.
aoi = ee.Geometry.Rectangle([88.0, 23.5, 90.0, 24.0])

print("AlphaEarth dataset loaded")
print(f"Area of Interest: {aoi.getInfo()}")


In [None]:
# Function to extract AlphaEarth values at sampling points
def extract_alpha_earth_values(year_start, year_end, season_name):
    """
    Extract AlphaEarth embeddings at every row of the global `sampling_points`.

    The collection in `embeddings_collection` is an ANNUAL product: per the
    Earth Engine catalog each image is time-stamped 1 January of its year and
    the globe is split into many tiles. Filtering it with a seasonal window
    (e.g. June-September) therefore matches no image at all, and taking
    `.first()` of the unfiltered-by-location collection returns an arbitrary
    tile. This function selects whole calendar years and, for each point,
    the tile that actually covers it.

    Parameters:
    - year_start: First calendar year to include (e.g., 2023)
    - year_end: Last calendar year to include (e.g., 2023). When the range
      spans several years, the most recent available year is used per point.
    - season_name: 'rainy' or 'winter'. Accepted for interface compatibility;
      an annual embedding has no seasonal resolution, so it does not change
      which image is sampled.

    Returns:
    - DataFrame with one row per sampling point (same order as
      `sampling_points`) and columns AE_00..AE_63. A row is all-NaN when the
      extraction for that point failed.
    """

    # filterDate uses an exclusive end, so [Jan 1 year_start, Jan 1 year_end+1)
    # captures every annual image in the requested range.
    date_start = f'{int(year_start)}-01-01'
    date_end = f'{int(year_end) + 1}-01-01'
    yearly_embeddings = embeddings_collection.filterDate(date_start, date_end)

    # Band names as delivered by Earth Engine; used for the NaN fallback so
    # failed rows share the same columns as successful ones (the previous
    # fallback used the already-renamed AE_xx keys, which produced duplicate
    # column names after the rename step below).
    raw_band_names = [f'A{i:02d}' for i in range(64)]

    total = len(sampling_points)
    alpha_earth_values = []

    for position, (idx, row) in enumerate(sampling_points.iterrows()):
        if position % 10 == 0:
            print(f"Extracting sample {position+1}/{total}...")

        point = ee.Geometry.Point([row['Longitude'], row['Latitude']])

        try:
            # Pick the tile covering this point; newest year first.
            image = (
                yearly_embeddings
                .filterBounds(point)
                .sort('system:time_start', False)
                .first()
            )
            # Sample at 10m resolution (AlphaEarth native resolution)
            sample = image.sample(point, scale=10)
            values = sample.first().toDictionary().getInfo()
            alpha_earth_values.append(values)
        except Exception as e:
            # Best effort: keep row alignment with sampling_points by
            # emitting an all-NaN row instead of aborting the whole run.
            print(f"Error extracting sample {idx}: {e}")
            alpha_earth_values.append(dict.fromkeys(raw_band_names, np.nan))

    # Convert to DataFrame
    ae_df = pd.DataFrame(alpha_earth_values)

    # Rename columns from A00-A63 to AE_00-AE_63 for clarity
    ae_columns = {
        col: f'AE_{col[1:]}'
        for col in ae_df.columns
        if col.startswith('A') and col[1:].isdigit()
    }
    ae_df = ae_df.rename(columns=ae_columns)

    print(f"\nExtracted {len(ae_df)} samples")
    print(f"Columns: {ae_df.columns.tolist()[:5]}...")
    print(f"Shape: {ae_df.shape}")

    return ae_df

print("Function defined")


In [None]:
# Run the extraction for the 2023 rainy season and cache the result as CSV.
# NOTE: the author estimates 10-15 minutes per year; for the full analysis
# the same call is looped over 2017-2024.

print("Extracting AlphaEarth for rainy season 2023...")
alpha_earth_rainy = extract_alpha_earth_values(
    year_start=2023,
    year_end=2023,
    season_name='rainy',
)

# Persist the table so later sessions can skip the slow extraction.
alpha_earth_rainy.to_csv('alpha_earth_rainy_2023.csv', index=False)
print("\nSaved to: alpha_earth_rainy_2023.csv")


In [None]:
# (Optional) Extract for all years 2017-2024
# Uncomment to run full extraction.
# The disabled loop below calls extract_alpha_earth_values once per year,
# tags each table with 'year' and 'season' columns, keeps it in
# all_years_data under the key 'rainy_<year>', and writes one CSV per year.

# all_years_data = {}
# for year in range(2017, 2025):
#     print(f"\n{'='*60}")
#     print(f"Extracting year {year}...")
#     print(f"{'='*60}")
#     
#     ae_data = extract_alpha_earth_values(year, year, 'rainy')
#     ae_data['year'] = year
#     ae_data['season'] = 'rainy'
#     all_years_data[f'rainy_{year}'] = ae_data
#     
#     ae_data.to_csv(f'alpha_earth_rainy_{year}.csv', index=False)

# This cell only prints a status line while the loop above stays commented out.
print("Extraction code prepared (uncomment to run full analysis)")


## Step 4: Prepare AlphaEarth Integration - Option B (Recommended)

### Option B: Add to Current Features

This integrates AlphaEarth embeddings while keeping all existing features.
Expected improvement: +3-6% in R².

In [None]:
# Read the base sampling table that the embeddings are attached to.
base_data = pd.read_csv('../data/Samples_100.csv')

# Place the AlphaEarth columns next to the base columns. Both indexes are
# reset first so rows are paired by position rather than by index label.
combined_data = pd.concat(
    [frame.reset_index(drop=True) for frame in (base_data, alpha_earth_rainy)],
    axis=1,
)

print(f"Combined data shape: {combined_data.shape}")
print(f"Base features: {list(base_data.columns[:10])}")
print(f"AlphaEarth columns: {[name for name in combined_data.columns if name.startswith('AE_')][:5]}...")


In [None]:
# AlphaEarth Integration Option B: Add to Current Features (RECOMMENDED)
# This is the recommended approach - keeps all existing features and adds AlphaEarth
# Load current base data
base_data = pd.read_csv('../data/Samples_100.csv')

# pd.concat(axis=1) pads with NaN when the two tables have different lengths,
# which would silently pair embeddings with the wrong samples. Fail loudly
# instead of writing a misaligned dataset.
if len(base_data) != len(alpha_earth_rainy):
    raise ValueError(
        f'Row count mismatch: {len(base_data)} base rows vs '
        f'{len(alpha_earth_rainy)} AlphaEarth rows'
    )

# Merge with AlphaEarth embeddings (rows are paired by position)
combined_data = pd.concat([base_data.reset_index(drop=True), alpha_earth_rainy.reset_index(drop=True)], axis=1)
print(f'Combined data shape: {combined_data.shape}')
print(f'Total features: {len(combined_data.columns)}')
print(f'  - Original features: {len(base_data.columns)}')
# Count the AE_ columns actually present instead of assuming there are 64.
print(f"  - AlphaEarth features: {sum(str(c).startswith('AE_') for c in combined_data.columns)}")
# Save Option B data (MAIN DATASET)
option_b_data = combined_data.copy()
option_b_data.to_csv('Option_B_RainyAE.csv', index=False)
print(f'\nOption B dataset saved: Option_B_RainyAE.csv')
print(f'Shape: {option_b_data.shape}')
print(f'Columns: {list(option_b_data.columns[:10])}...')
# Info about what this option includes:
# NOTE(review): the band/channel counts quoted below are fixed text, not
# derived from the data - confirm they match the current feature set.
print()
print('=' * 70)
print('OPTION B DETAILS (RECOMMENDED)')
print('=' * 70)
print('✓ Keeps all original features (25 bands: metals, indices, LULC, soil)')
print('✓ Adds 64-dimensional AlphaEarth embeddings')
print('✓ Total CNN input: ~89 channels')
print('✓ Expected improvement: +3-6% R² (rainy), +2-4% R² (winter)')
print('✓ Best balance of performance vs complexity')
print('=' * 70)


## Step 5: Create CNN Patches with AlphaEarth

In [None]:
# Function to extract raster patches (existing code, adapted)
def extract_alpha_earth_patches(alpha_earth_path, coordinates, patch_size=32, resolution=10):
    """
    Extract square patches from an AlphaEarth raster, centered on each location.

    Windows that extend past the raster edge are read in "boundless" mode, so
    the missing pixels are zero-filled on the side where they are actually
    missing and the patch stays centered on the requested location. (Padding
    a clipped read only at the bottom/right shifts patches near the top/left
    edge.)

    Parameters:
    - alpha_earth_path: Path to a raster readable by rasterio (e.g. GeoTIFF)
    - coordinates: Sequence of (lon, lat) tuples. They are passed to
      `src.index` as (x, y), so they must be expressed in the raster's CRS
      (assumes the raster is in geographic lon/lat - TODO confirm).
    - patch_size: Patch edge length in pixels (32x32 = 320m x 320m at 10m
      resolution)
    - resolution: Pixel resolution in meters. Currently informational only;
      it is not used in the computation.

    Returns:
    - numpy array of shape (N, patch_size, patch_size, B), where B is the
      raster's band count (64 for an AlphaEarth export). Locations whose
      extraction fails yield an all-zero patch so that N always equals
      len(coordinates).
    """

    patches = []
    half_size = patch_size // 2

    with rasterio.open(alpha_earth_path) as src:
        # Use the raster's real band count so fallback patches always have
        # the same shape as successfully read ones.
        n_bands = src.count

        for idx, (lon, lat) in enumerate(coordinates):
            if idx % 10 == 0:
                print(f"Extracting patch {idx+1}/{len(coordinates)}...")

            try:
                # Convert geo coordinates to pixel coordinates
                row, col = src.index(lon, lat)

                # Define window (patch_size x patch_size pixels, centered on location)
                window = Window(col - half_size, row - half_size, patch_size, patch_size)

                # Read all bands; pixels outside the raster are filled with 0
                patch = src.read(window=window, boundless=True, fill_value=0)

                # Transpose (bands, rows, cols) -> (rows, cols, bands)
                patches.append(np.transpose(patch, (1, 2, 0)))
            except Exception as e:
                print(f"Error extracting patch {idx}: {e}")
                # Create zero patch if extraction fails
                patches.append(np.zeros((patch_size, patch_size, n_bands)))

    if not patches:
        # Keep the documented 4-D layout even when no coordinates were given.
        return np.empty((0, patch_size, patch_size, n_bands))

    return np.array(patches)

print("Patch extraction function defined")


## Step 6: Summary and Next Steps

In [None]:
# Print the end-of-notebook summary banner.
# The cell was previously wrapped in a single print("""...""") call, which
# printed the source text of these statements instead of executing them.
# NOTE(review): of the files listed below, only Option_B_RainyAE.csv and
# alpha_earth_rainy_2023.csv are written by the cells visible in this
# notebook - confirm where the others are produced.
print("\n" + "="*80)
print("ALPHAEARTH DATA PREPARATION COMPLETE")
print("="*80)

print("\nGenerated Files:")
print("  1. Option_A_RainyAE.csv - Replace indices approach")
print("  2. Option_B_RainyAE.csv - Add to current features (RECOMMENDED)")
print("  3. Option_C_RainyAE.csv - PCA-reduced approach")
print("  4. Option_D_RainyAE.csv - MLP enhancement only")
print("  5. pca_alpha_earth.pkl - PCA model for Option C")
print("  6. alpha_earth_rainy_2023.csv - Raw AlphaEarth values")

print("\nNext Steps:")
print("  1. Copy these files to ../data/ directory")
print("  2. Modify existing model notebooks to use these files")
print("  3. Update data loading and CNN input preparation code")
print("  4. Retrain models with AlphaEarth integration")
print("\n" + "="*80)


In [None]:
# Report the on-disk size, in megabytes, of each option dataset that exists
# in the current working directory; missing files are skipped silently.
import os

files = [
    'Option_A_RainyAE.csv',
    'Option_B_RainyAE.csv',
    'Option_C_RainyAE.csv',
    'Option_D_RainyAE.csv'
]

print("\nOutput Files:")
for path in files:
    if not os.path.exists(path):
        continue
    print(f"  {path}: {os.path.getsize(path) / (1024 * 1024):.2f} MB")
