# Multi-City Green Space Detection
## Training Random Forest with WorldCover 2021 as Ground Truth

**Training Cities:** every city found with a complete set of input files (9 at the time of writing), pooled for robust model training

**Key Features:**
- Uses **WorldCover 2021** as ground truth for training
- Green classes: Tree cover (10), Shrubland (20), Grassland (30), Mangroves (95)
- Multi-temporal Sentinel-2 data (April, August, November)
- 21 bands: 4 spectral bands (B02, B03, B04, B08) × 3 months + 3 vegetation indices (NDVI, EVI, SAVI) × 3 months
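- **Multi-city training**: pixels from all cities are pooled into one training set for better generalization (the test set is a random 20% pixel sample, not a held-out city)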

## 1. Import Libraries

In [None]:
import json
import os
import glob
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from pathlib import Path
import geopandas as gpd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import warnings
# Silence every warning category for the rest of the session (library notices included).
# NOTE(review): this also hides genuinely useful warnings; consider narrowing the filter.
warnings.simplefilter('ignore')

print("✓ Libraries imported successfully")

## 2. Configuration

In [None]:
# Base paths.
# The project root can be overridden with the INFRARED_CITY_BASE_PATH environment
# variable so the notebook runs on other machines; when it is unset or empty the
# original hard-coded location is used.
BASE_PATH = (
    os.environ.get("INFRARED_CITY_BASE_PATH")
    or "/Users/timgotschim/Documents/LLM/infrared.city"
)
STACKS_FOLDER = os.path.join(BASE_PATH, "21 Stacks")  # <city>_MultiMonth_stack.tif files
GEOJSON_FOLDER = os.path.join(BASE_PATH, "sentinel_data")  # per-city GeoJSON files
WORLDCOVER_FOLDER = os.path.join(BASE_PATH, "worldcover")  # per-city WorldCover rasters

# Output folder shared by all runs
OUTPUT_FOLDER = os.path.join(BASE_PATH, "multi_city_training")
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Each execution writes into its own timestamped sub-folder so runs never overwrite each other
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_FOLDER = os.path.join(OUTPUT_FOLDER, f"run_{timestamp}")
os.makedirs(RUN_FOLDER, exist_ok=True)

# WorldCover 2021 class codes mapped to label 1 ("green"); every other code becomes label 0
GREEN_CLASSES = [10, 20, 30, 95]  # Tree, Shrub, Grass, Mangroves

print("✓ Configuration loaded")
print(f"  Stacks folder: {STACKS_FOLDER}")
print(f"  GeoJSON folder: {GEOJSON_FOLDER}")
print(f"  WorldCover folder: {WORLDCOVER_FOLDER}")
print(f"  Output folder: {RUN_FOLDER}")

## 3. Discover Available Cities
### Find all cities with Multi-Month stacks

In [None]:
print("="*70)
print("DISCOVERING AVAILABLE CITIES")
print("="*70)

# Find all Multi-Month stack files
stack_files = glob.glob(os.path.join(STACKS_FOLDER, "*_MultiMonth_stack.tif"))

print(f"\nFound {len(stack_files)} Multi-Month stacks:")

cities_data = []

for stack_file in sorted(stack_files):
    # Extract city name from filename
    filename = os.path.basename(stack_file)
    city_name = filename.replace("_MultiMonth_stack.tif", "")
    
    # Try to find corresponding GeoJSON
    geojson_patterns = [
        os.path.join(GEOJSON_FOLDER, f"{city_name}.geojson"),
        os.path.join(GEOJSON_FOLDER, city_name, f"{city_name}.geojson"),
        os.path.join(GEOJSON_FOLDER, f"{city_name.lower()}.geojson"),
    ]
    
    geojson_file = None
    for pattern in geojson_patterns:
        if os.path.exists(pattern):
            geojson_file = pattern
            break
    
    # Try to find WorldCover file
    worldcover_patterns = [
        os.path.join(WORLDCOVER_FOLDER, f"{city_name}_WorldCover_2021.tif"),
        os.path.join(WORLDCOVER_FOLDER, f"{city_name}_WorldCover.tif"),
        os.path.join(WORLDCOVER_FOLDER, city_name, f"{city_name}_WorldCover_2021.tif"),
    ]
    
    worldcover_file = None
    for pattern in worldcover_patterns:
        if os.path.exists(pattern):
            worldcover_file = pattern
            break
    
    cities_data.append({
        "name": city_name,
        "stack_file": stack_file,
        "geojson_file": geojson_file,
        "worldcover_file": worldcover_file
    })
    
    status_geojson = "✓" if geojson_file else "✗"
    status_worldcover = "✓" if worldcover_file else "✗"
    
    print(f"  {city_name:20s} - Stack: ✓  GeoJSON: {status_geojson}  WorldCover: {status_worldcover}")

# Filter cities with all required data
complete_cities = [
    city for city in cities_data 
    if city["geojson_file"] and city["worldcover_file"]
]

print(f"\n{'='*70}")
print(f"Cities with complete data: {len(complete_cities)}")
print(f"{'='*70}")

if len(complete_cities) == 0:
    raise ValueError("No cities with complete data found! Check your file paths.")

for city in complete_cities:
    print(f"  ✓ {city['name']}")

## 4. Load and Process All Cities
### Load Sentinel-2 stacks and create WorldCover labels for each city

In [None]:
print("\n" + "="*70)
print("LOADING AND PROCESSING ALL CITIES")
print("="*70)

# Expected number of bands (set to None to auto-detect from first city)
EXPECTED_BANDS = None

all_X = []  # Features from all cities
all_y = []  # Labels from all cities
city_info = []  # Track which city each sample came from
skipped_cities = []  # Track skipped cities

for city_data in tqdm(complete_cities, desc="Processing cities"):
    city_name = city_data["name"]
    stack_file = city_data["stack_file"]
    worldcover_file = city_data["worldcover_file"]
    
    print(f"\n{'='*70}")
    print(f"Processing: {city_name}")
    print(f"{'='*70}")
    
    try:
        # Load Sentinel-2 stack
        with rasterio.open(stack_file) as src:
            X_stack = src.read()  # Shape: (n_bands, height, width)
            stack_transform = src.transform
            stack_shape = (src.height, src.width)
            stack_crs = src.crs
        
        n_bands = X_stack.shape[0]
        print(f"  ✓ Loaded Sentinel-2 stack: {X_stack.shape} ({n_bands} bands)")
        
        # Set expected bands from first city, or check consistency
        if EXPECTED_BANDS is None:
            EXPECTED_BANDS = n_bands
            print(f"  ℹ Setting expected bands to {EXPECTED_BANDS}")
        elif n_bands != EXPECTED_BANDS:
            print(f"  ⚠ SKIPPING: Expected {EXPECTED_BANDS} bands, but found {n_bands} bands")
            skipped_cities.append({"name": city_name, "reason": f"Band mismatch: {n_bands} vs {EXPECTED_BANDS}"})
            continue
        
        # Load and reproject WorldCover to match Sentinel-2
        with rasterio.open(worldcover_file) as src:
            worldcover_data = np.empty(stack_shape, dtype=np.uint8)
            
            reproject(
                source=rasterio.band(src, 1),
                destination=worldcover_data,
                src_transform=src.transform,
                src_crs=src.crs,
                dst_transform=stack_transform,
                dst_crs=stack_crs,
                resampling=Resampling.nearest
            )
        
        # Convert to binary green/non-green labels
        labels = np.isin(worldcover_data, GREEN_CLASSES).astype(np.uint8)
        
        green_percentage = 100 * labels.sum() / labels.size
        print(f"  ✓ WorldCover labels: {labels.shape} ({green_percentage:.2f}% green)")
        
        # Reshape for sklearn: (n_samples, n_features)
        n_pixels = X_stack.shape[1] * X_stack.shape[2]
        
        X = X_stack.reshape(n_bands, -1).T  # Shape: (n_pixels, n_bands)
        y = labels.flatten()  # Shape: (n_pixels,)
        
        # Remove NaN values
        valid_mask = ~np.isnan(X).any(axis=1)
        X_clean = X[valid_mask]
        y_clean = y[valid_mask]
        
        print(f"  ✓ Valid samples: {len(X_clean):,} ({100*len(X_clean)/n_pixels:.1f}% of pixels)")
        print(f"    - Green: {np.sum(y_clean == 1):,} ({100*np.sum(y_clean == 1)/len(y_clean):.2f}%)")
        print(f"    - Non-green: {np.sum(y_clean == 0):,} ({100*np.sum(y_clean == 0)/len(y_clean):.2f}%)")
        
        # Add to combined dataset
        all_X.append(X_clean)
        all_y.append(y_clean)
        city_info.extend([city_name] * len(X_clean))
        
    except Exception as e:
        print(f"  ✗ Error processing {city_name}: {e}")
        skipped_cities.append({"name": city_name, "reason": str(e)})
        continue

print(f"\n{'='*70}")
print("DATA AGGREGATION")
print(f"{'='*70}")

# Report skipped cities
if skipped_cities:
    print(f"\n⚠ Skipped {len(skipped_cities)} cities due to issues:")
    for city in skipped_cities:
        print(f"  - {city['name']}: {city['reason']}")

# Check if we have any data
if len(all_X) == 0:
    raise ValueError("No valid city data loaded! Check that all stacks have the same number of bands.")

# Combine all data
X_combined = np.vstack(all_X)
y_combined = np.hstack(all_y)
city_info = np.array(city_info)

print(f"\nCombined dataset:")
print(f"  Cities included: {len(all_X)}")
print(f"  Total samples: {len(X_combined):,}")
print(f"  Features (bands): {X_combined.shape[1]}")
print(f"  Green samples: {np.sum(y_combined == 1):,} ({100*np.sum(y_combined == 1)/len(y_combined):.2f}%)")
print(f"  Non-green samples: {np.sum(y_combined == 0):,} ({100*np.sum(y_combined == 0)/len(y_combined):.2f}%)")

# Save city distribution
print(f"\nSamples per city:")
unique_cities = np.unique(city_info)
for city_name in unique_cities:
    city_samples = np.sum(city_info == city_name)
    print(f"  {city_name:20s}: {city_samples:>10,} samples")

print(f"{'='*70}")

## 5. Train-Test Split
### Split data for training and validation

In [None]:
print("\n" + "="*70)
print("TRAIN-TEST SPLIT")
print("="*70)

# Split data (80-20)
X_train, X_test, y_train, y_test = train_test_split(
    X_combined, y_combined, 
    test_size=0.2, 
    random_state=42, 
    stratify=y_combined
)

print(f"\nDataset split:")
print(f"  Training samples: {len(X_train):,}")
print(f"    - Green: {np.sum(y_train == 1):,} ({100*np.sum(y_train == 1)/len(y_train):.2f}%)")
print(f"    - Non-green: {np.sum(y_train == 0):,} ({100*np.sum(y_train == 0)/len(y_train):.2f}%)")
print(f"\n  Testing samples: {len(X_test):,}")
print(f"    - Green: {np.sum(y_test == 1):,} ({100*np.sum(y_test == 1)/len(y_test):.2f}%)")
print(f"    - Non-green: {np.sum(y_test == 0):,} ({100*np.sum(y_test == 0)/len(y_test):.2f}%)")

print(f"{'='*70}")

## 6. Train Random Forest Model
### Train on multi-city dataset

In [None]:
print("\n" + "="*70)
print("TRAINING RANDOM FOREST MODEL")
print("="*70)

# Initialize Random Forest
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=25,
    min_samples_split=50,
    min_samples_leaf=20,
    max_features='sqrt',
    random_state=42,
    n_jobs=-1,
    verbose=1
)

print(f"\nRandom Forest parameters:")
print(f"  n_estimators: {rf.n_estimators}")
print(f"  max_depth: {rf.max_depth}")
print(f"  min_samples_split: {rf.min_samples_split}")
print(f"  min_samples_leaf: {rf.min_samples_leaf}")
print(f"  max_features: {rf.max_features}")

print(f"\nTraining Random Forest...")
print(f"  Training on {len(complete_cities)} cities")
print(f"  Training samples: {len(X_train):,}")

rf.fit(X_train, y_train)

print(f"\n✓ Model trained successfully")
print(f"{'='*70}")

## 7. Evaluate Model Performance

In [None]:
print("\n" + "="*70)
print("MODEL EVALUATION")
print("="*70)

# Make predictions
y_pred = rf.predict(X_test)

# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, zero_division=0)
recall = recall_score(y_test, y_pred, zero_division=0)
f1 = f1_score(y_test, y_pred, zero_division=0)

print(f"\nModel Performance (trained on {len(complete_cities)} cities):")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1-Score:  {f1:.4f}")

# Confusion Matrix
cm = confusion_matrix(y_test, y_pred)
print(f"\nConfusion Matrix:")
print(f"                 Predicted")
print(f"               Non-Green  Green")
print(f"Actual Non-Green  {cm[0,0]:>8,}  {cm[0,1]:>8,}")
print(f"       Green      {cm[1,0]:>8,}  {cm[1,1]:>8,}")

# Save metrics
metrics = {
    "model": "RandomForest",
    "ground_truth": "WorldCover_2021",
    "training_cities": [city['name'] for city in complete_cities],
    "n_cities": len(complete_cities),
    "total_training_samples": int(len(X_train)),
    "total_testing_samples": int(len(X_test)),
    "accuracy": float(accuracy),
    "precision": float(precision),
    "recall": float(recall),
    "f1_score": float(f1),
    "confusion_matrix": cm.tolist()
}

with open(os.path.join(RUN_FOLDER, "metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)

print(f"\n✓ Metrics saved to: {RUN_FOLDER}/metrics.json")
print(f"{'='*70}")

## 8. Visualize Confusion Matrix

In [None]:
# Plot the confusion matrix `cm` from the evaluation step
# (rows = true label, columns = predicted label; order Non-Green, Green)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Non-Green', 'Green'],
            yticklabels=['Non-Green', 'Green'],
            cbar_kws={'label': 'Count'})
# len(all_X) = cities that actually contributed samples (matches the feature-importance
# plot); complete_cities would also count cities skipped while loading
plt.title(f'Confusion Matrix - Multi-City Random Forest\n(Trained on {len(all_X)} cities)',
          fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(RUN_FOLDER, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

print("✓ Confusion matrix saved")

## 9. Feature Importance Analysis

In [None]:
# Get feature importances (scikit-learn's impurity-based importance, one value per band)
importances = rf.feature_importances_
n_features = len(importances)

# Human-readable band names for the known layouts.
# NOTE(review): the 21- and 12-band names ASSUME the stack is ordered month by month
# (Apr, Aug, Nov) with the spectral bands first inside each month, then the indices -
# confirm against the script that builds the stacks.
MONTHS = ['Apr', 'Aug', 'Nov']
SPECTRAL_BANDS = ['B02', 'B03', 'B04', 'B08']
INDEX_BANDS = ['NDVI', 'EVI', 'SAVI']

if n_features == 21:
    # 21 bands: 4 spectral + 3 indices per month × 3 months
    per_month_bands = SPECTRAL_BANDS + INDEX_BANDS
elif n_features == 12:
    # 12 bands: 4 spectral per month × 3 months (no indices)
    per_month_bands = SPECTRAL_BANDS
else:
    per_month_bands = None

if per_month_bands is not None:
    band_names = [f'{band}-{month}' for month in MONTHS for band in per_month_bands]
else:
    # Unknown layout (e.g. 14 bands): fall back to generic names
    band_names = [f'Band_{i+1}' for i in range(n_features)]
    print(f"ℹ Using generic band names for {n_features} bands")

# Sort by importance (descending)
indices = np.argsort(importances)[::-1]

# Plot feature importances
plt.figure(figsize=(12, max(8, n_features * 0.4)))
plt.barh(range(n_features), importances[indices], color='steelblue')
plt.yticks(range(n_features), [band_names[i] for i in indices])
# barh draws position 0 at the bottom; flip so the most important band is on top
plt.gca().invert_yaxis()
plt.xlabel('Feature Importance', fontsize=12)
plt.title(f'Random Forest Feature Importance\n(Trained on {len(all_X)} cities, {n_features} bands)',
          fontsize=14, fontweight='bold')
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(RUN_FOLDER, 'feature_importance.png'), dpi=300, bbox_inches='tight')
plt.show()

print("✓ Feature importance plot saved")
print(f"\nTop 10 most important features:")
for rank, idx in enumerate(indices[:10], start=1):
    print(f"  {rank:2d}. {band_names[idx]:12s}: {importances[idx]:.4f}")

## 10. Save Trained Model

In [None]:
import joblib

# Persist the fitted forest inside this run's folder.
# NOTE: joblib files are pickles - only ever load them from trusted locations.
model_file = os.path.join(RUN_FOLDER, 'random_forest_model.pkl')
joblib.dump(rf, model_file)

reload_instructions = "\n".join([
    f"✓ Model saved to: {model_file}",
    "\nTo load the model later:",
    "  import joblib",
    f"  rf = joblib.load('{model_file}')",
])
print(reload_instructions)

## 11. Per-City Performance Analysis (Optional)
### Evaluate model performance on each city individually

In [None]:
print("\n" + "="*70)
print("PER-CITY PERFORMANCE ANALYSIS")
print("="*70)

per_city_results = []

for city_data in complete_cities:
    city_name = city_data["name"]
    stack_file = city_data["stack_file"]
    worldcover_file = city_data["worldcover_file"]
    
    print(f"\n{city_name}:")
    
    try:
        # Load city data
        with rasterio.open(stack_file) as src:
            X_stack = src.read()
            stack_transform = src.transform
            stack_shape = (src.height, src.width)
            stack_crs = src.crs
        
        with rasterio.open(worldcover_file) as src:
            worldcover_data = np.empty(stack_shape, dtype=np.uint8)
            reproject(
                source=rasterio.band(src, 1),
                destination=worldcover_data,
                src_transform=src.transform,
                src_crs=src.crs,
                dst_transform=stack_transform,
                dst_crs=stack_crs,
                resampling=Resampling.nearest
            )
        
        labels = np.isin(worldcover_data, GREEN_CLASSES).astype(np.uint8)
        
        # Reshape and clean
        X = X_stack.reshape(X_stack.shape[0], -1).T
        y = labels.flatten()
        valid_mask = ~np.isnan(X).any(axis=1)
        X_city = X[valid_mask]
        y_city = y[valid_mask]
        
        # Predict
        y_pred_city = rf.predict(X_city)
        
        # Calculate metrics
        acc = accuracy_score(y_city, y_pred_city)
        prec = precision_score(y_city, y_pred_city, zero_division=0)
        rec = recall_score(y_city, y_pred_city, zero_division=0)
        f1_city = f1_score(y_city, y_pred_city, zero_division=0)
        
        print(f"  Accuracy:  {acc:.4f}")
        print(f"  Precision: {prec:.4f}")
        print(f"  Recall:    {rec:.4f}")
        print(f"  F1-Score:  {f1_city:.4f}")
        
        per_city_results.append({
            "city": city_name,
            "accuracy": float(acc),
            "precision": float(prec),
            "recall": float(rec),
            "f1_score": float(f1_city)
        })
        
    except Exception as e:
        print(f"  Error: {e}")

# Save per-city results
with open(os.path.join(RUN_FOLDER, "per_city_metrics.json"), "w") as f:
    json.dump(per_city_results, f, indent=2)

print(f"\n{'='*70}")
print(f"✓ Per-city metrics saved")

## 12. Summary Report

In [None]:
print("\n" + "="*80)
print("MULTI-CITY TRAINING - SUMMARY REPORT")
print("="*80)

print(f"\nGround Truth: WorldCover 2021")
print(f"Green Classes: Tree cover (10), Shrubland (20), Grassland (30), Mangroves (95)")

print(f"\nTraining Data:")
print(f"  Cities: {len(complete_cities)}")
for city in complete_cities:
    print(f"    - {city['name']}")

print(f"\n  Total training samples: {len(X_train):,}")
print(f"  Total testing samples:  {len(X_test):,}")

print(f"\nModel Performance (Overall):")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1-Score:  {f1:.4f}")

if per_city_results:
    print(f"\nPer-City Performance (Average):")
    avg_acc = np.mean([r['accuracy'] for r in per_city_results])
    avg_prec = np.mean([r['precision'] for r in per_city_results])
    avg_rec = np.mean([r['recall'] for r in per_city_results])
    avg_f1 = np.mean([r['f1_score'] for r in per_city_results])
    print(f"  Accuracy:  {avg_acc:.4f}")
    print(f"  Precision: {avg_prec:.4f}")
    print(f"  Recall:    {avg_rec:.4f}")
    print(f"  F1-Score:  {avg_f1:.4f}")

print(f"\nOutput Files:")
print(f"  Results folder: {RUN_FOLDER}")
print(f"  - metrics.json (overall performance)")
print(f"  - per_city_metrics.json (individual city performance)")
print(f"  - confusion_matrix.png")
print(f"  - feature_importance.png")
print(f"  - random_forest_model.pkl (trained model)")

print(f"\n" + "="*80)
print(f"✓ TRAINING COMPLETE!")
print(f"="*80)