# Multi-City Green Space Detection
## Training U-Net with WorldCover 2021 as Ground Truth

Training cities: 11 cities for robust model generalisation

- Uses WorldCover 2021 as ground truth (green classes: tree cover, shrubland, grassland, mangroves)
- Multi-temporal Sentinel-2 data (April, August, November)
- 21 bands: 4 spectral + 3 vegetation indices per month
- Encoder-decoder with skip connections for pixel-wise segmentation
- Patch-based training (64x64 patches with augmentation)

## 1. Import Libraries

In [2]:
%pip install tensorflow rasterio scikit-learn matplotlib seaborn tqdm

import json
import os
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# TensorFlow/Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")
print("Libraries imported successfully")

Collecting tensorflow
  Downloading tensorflow-2.20.0-cp313-cp313-macosx_12_0_arm64.whl.metadata (4.5 kB)
Collecting absl-py>=1.0.0 (from tensorflow)
  Using cached absl_py-2.3.1-py3-none-any.whl.metadata (3.3 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.12.19-py2.py3-none-any.whl.metadata (1.0 kB)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow)
  Downloading gast-0.7.0-py3-none-any.whl.metadata (1.5 kB)
Collecting google_pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl.metadata (5.2 kB)
Collecting opt_einsum>=2.3.2 (from tensorflow)
  Downloading opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)
Collecting protobuf>=5.28.0 (from tensorflow)
  Download

## 2. Configuration

In [None]:
# Base paths - using relative paths from project root
# Run notebooks from the project root directory: python -m jupyter notebook
import os

# Find project root (go up from notebooks/training/)
if os.path.exists("data") and os.path.exists("models"):
    PROJECT_ROOT = os.getcwd()
elif os.path.exists("../../data") and os.path.exists("../../models"):
    PROJECT_ROOT = os.path.abspath("../..")
else:
    PROJECT_ROOT = os.getcwd()
    print(f"Warning: Could not detect project root. Using: {PROJECT_ROOT}")

# Derived paths
DATA_PATH = os.path.join(PROJECT_ROOT, "data")
MODELS_PATH = os.path.join(PROJECT_ROOT, "models")

# Output folder
OUTPUT_FOLDER = os.path.join(PROJECT_ROOT, "outputs", "unet_training")
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
os.makedirs(MODELS_PATH, exist_ok=True)

# Create timestamped run folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_FOLDER = os.path.join(OUTPUT_FOLDER, f"run_{timestamp}")
os.makedirs(RUN_FOLDER, exist_ok=True)

# WorldCover green classes
GREEN_CLASSES = [10, 20, 30, 95]  # Tree, Shrub, Grass, Mangroves

# U-Net Configuration
PATCH_SIZE = 64  # Size of image patches (64x64)
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 0.001
N_FILTERS_START = 32  # Starting number of filters

# Define cities with their specific file locations
CITY_FILES = {
    "Amsterdam": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Amsterdam_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Amsterdam_WorldCover_2021.tif"),
    },
    "Auckland": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Auckland_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Auckland_WorldCover_2021.tif"),
    },
    "Barcelona": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Barcelona_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Barcelona_WorldCover_2021.tif"),
    },
    "Sydney": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Sydney_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Sydney_WorldCover_2021.tif"),
    },
    "Toronto": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Toronto_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Toronto_WorldCover_2021.tif"),
    },
    "Vienna": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Wien_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Vienna_WorldCover_2021.tif"),
    },
    "London": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "London_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "London_WorldCover_2021.tif"),
    },
    "Melbourne": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Melbourne_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Melbourne_WorldCover_2021.tif"),
    },
    "Paris": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Paris_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Paris_WorldCover_2021.tif"),
    },
    "San_Francisco": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "San_Francisco_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "San_Francisco_WorldCover_2021.tif"),
    },
    "Seattle": {
        "stack": os.path.join(DATA_PATH, "sentinel_stacks", "Seattle_MultiMonth_stack.tif"),
        "worldcover": os.path.join(DATA_PATH, "worldcover", "Seattle_WorldCover_2021.tif"),
    },
}

print("Configuration loaded")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Data path: {DATA_PATH}")
print(f"  Models path: {MODELS_PATH}")
print(f"  Output folder: {RUN_FOLDER}")
print(f"  Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Target cities: {len(CITY_FILES)}")

## 3. Discover Available Cities

In [4]:
print("="*70)
print("DISCOVERING AVAILABLE CITIES")
print("="*70)

cities_data = []

for city_name, paths in CITY_FILES.items():
    stack_file = paths["stack"]
    worldcover_file = paths["worldcover"]
    
    has_stack = os.path.exists(stack_file)
    has_worldcover = os.path.exists(worldcover_file)
    
    status_stack = "Y" if has_stack else "N"
    status_worldcover = "Y" if has_worldcover else "N"
    
    print(f"  {city_name:15s} - Stack: {status_stack}  WorldCover: {status_worldcover}")
    
    if has_stack and has_worldcover:
        cities_data.append({
            "name": city_name,
            "stack_file": stack_file,
            "worldcover_file": worldcover_file
        })

complete_cities = cities_data

print(f"\nCities with complete data: {len(complete_cities)}/{len(CITY_FILES)}")

if len(complete_cities) == 0:
    raise ValueError("No cities with complete data found!")

DISCOVERING AVAILABLE CITIES
  Amsterdam       - Stack: Y  WorldCover: Y
  Auckland        - Stack: Y  WorldCover: Y
  Barcelona       - Stack: Y  WorldCover: Y
  Sydney          - Stack: Y  WorldCover: Y
  Toronto         - Stack: Y  WorldCover: Y
  Vienna          - Stack: Y  WorldCover: Y
  Melbourne       - Stack: Y  WorldCover: Y
  Paris           - Stack: Y  WorldCover: Y
  San_Francisco   - Stack: Y  WorldCover: Y
  Seattle         - Stack: Y  WorldCover: Y

Cities with complete data: 10/10


## 4. Helper Functions for Patch Extraction

In [5]:
def extract_patches(image, mask, patch_size, stride=None):
    """
    Extract patches from image and mask.
    
    Args:
        image: numpy array of shape (H, W, C)
        mask: numpy array of shape (H, W)
        patch_size: size of patches to extract
        stride: stride between patches (default: patch_size // 2 for overlap)
    
    Returns:
        patches_X: list of image patches
        patches_y: list of mask patches
    """
    if stride is None:
        stride = patch_size // 2
    
    H, W, C = image.shape
    patches_X = []
    patches_y = []
    
    for i in range(0, H - patch_size + 1, stride):
        for j in range(0, W - patch_size + 1, stride):
            patch_X = image[i:i+patch_size, j:j+patch_size, :]
            patch_y = mask[i:i+patch_size, j:j+patch_size]
            
            # Skip patches with NaN values
            if np.isnan(patch_X).any():
                continue
            
            patches_X.append(patch_X)
            patches_y.append(patch_y)
    
    return patches_X, patches_y


def normalize_image(image):
    """
    Normalize image to [0, 1] range per channel.
    """
    image = image.astype(np.float32)
    for c in range(image.shape[-1]):
        channel = image[:, :, c]
        min_val = np.nanmin(channel)
        max_val = np.nanmax(channel)
        if max_val > min_val:
            image[:, :, c] = (channel - min_val) / (max_val - min_val)
        else:
            image[:, :, c] = 0
    return image


def augment_patch(patch_X, patch_y):
    """
    Apply random augmentations to a patch.
    """
    augmented_X = [patch_X]
    augmented_y = [patch_y]
    
    # Horizontal flip
    augmented_X.append(np.fliplr(patch_X))
    augmented_y.append(np.fliplr(patch_y))
    
    # Vertical flip
    augmented_X.append(np.flipud(patch_X))
    augmented_y.append(np.flipud(patch_y))
    
    # 90 degree rotation
    augmented_X.append(np.rot90(patch_X, k=1))
    augmented_y.append(np.rot90(patch_y, k=1))
    
    return augmented_X, augmented_y


print("Helper functions defined")

Helper functions defined


## 5. Load and Create Patches from All Cities

In [6]:
print("\n" + "="*70)
print("LOADING AND CREATING PATCHES")
print("="*70)

all_patches_X = []
all_patches_y = []
n_bands = None

for city_data in tqdm(complete_cities, desc="Processing cities"):
    city_name = city_data["name"]
    stack_file = city_data["stack_file"]
    worldcover_file = city_data["worldcover_file"]
    
    print(f"\nProcessing: {city_name}")
    
    try:
        # Load Sentinel-2 stack
        with rasterio.open(stack_file) as src:
            X_stack = src.read()  # (bands, H, W)
            stack_transform = src.transform
            stack_shape = (src.height, src.width)
            stack_crs = src.crs
        
        if n_bands is None:
            n_bands = X_stack.shape[0]
            print(f"  Setting n_bands to {n_bands}")
        elif X_stack.shape[0] != n_bands:
            print(f"  SKIPPING: Band mismatch ({X_stack.shape[0]} vs {n_bands})")
            continue
        
        # Transpose to (H, W, C) for easier handling
        X_image = X_stack.transpose(1, 2, 0)  # (H, W, bands)
        
        # Load and reproject WorldCover
        with rasterio.open(worldcover_file) as src:
            worldcover_data = np.empty(stack_shape, dtype=np.uint8)
            reproject(
                source=rasterio.band(src, 1),
                destination=worldcover_data,
                src_transform=src.transform,
                src_crs=src.crs,
                dst_transform=stack_transform,
                dst_crs=stack_crs,
                resampling=Resampling.nearest
            )
        
        # Create binary mask
        mask = np.isin(worldcover_data, GREEN_CLASSES).astype(np.float32)
        
        # Normalize image
        X_norm = normalize_image(X_image)
        
        # Extract patches
        patches_X, patches_y = extract_patches(
            X_norm, mask, 
            patch_size=PATCH_SIZE, 
            stride=PATCH_SIZE // 2
        )
        
        print(f"  Extracted {len(patches_X)} patches")
        
        # Apply augmentation
        for px, py in zip(patches_X, patches_y):
            aug_X, aug_y = augment_patch(px, py)
            all_patches_X.extend(aug_X)
            all_patches_y.extend(aug_y)
        
        print(f"  After augmentation: {len(all_patches_X)} total patches")
        
    except Exception as e:
        print(f"  Error: {e}")
        continue

# Convert to numpy arrays
X_all = np.array(all_patches_X, dtype=np.float32)
y_all = np.array(all_patches_y, dtype=np.float32)

# Expand mask dimensions for U-Net output
y_all = np.expand_dims(y_all, axis=-1)

print(f"\n{'='*70}")
print(f"DATA SUMMARY")
print(f"{'='*70}")
print(f"  Total patches: {len(X_all):,}")
print(f"  Patch shape: {X_all.shape[1:]}")
print(f"  Mask shape: {y_all.shape[1:]}")
print(f"  Memory: {X_all.nbytes / 1e9:.2f} GB")


LOADING AND CREATING PATCHES


Processing cities:  30%|███       | 3/10 [00:00<00:00, 22.74it/s]


Processing: Amsterdam
  Setting n_bands to 21
  Extracted 37 patches
  After augmentation: 148 total patches

Processing: Auckland
  Extracted 71 patches
  After augmentation: 432 total patches

Processing: Barcelona
  Extracted 139 patches
  After augmentation: 988 total patches

Processing: Sydney
  Extracted 87 patches
  After augmentation: 1336 total patches

Processing: Toronto


Processing cities:  60%|██████    | 6/10 [00:00<00:00, 25.55it/s]

  Extracted 78 patches
  After augmentation: 1648 total patches

Processing: Vienna
  Extracted 52 patches
  After augmentation: 1856 total patches

Processing: Melbourne
  Extracted 38 patches
  After augmentation: 2008 total patches

Processing: Paris
  Extracted 20 patches
  After augmentation: 2088 total patches

Processing: San_Francisco
  Extracted 18 patches
  After augmentation: 2160 total patches

Processing: Seattle


Processing cities: 100%|██████████| 10/10 [00:00<00:00, 26.66it/s]

  Extracted 62 patches
  After augmentation: 2408 total patches






DATA SUMMARY
  Total patches: 2,408
  Patch shape: (64, 64, 21)
  Mask shape: (64, 64, 1)
  Memory: 0.83 GB


## 6. Train-Test Split

In [7]:
print("\n" + "="*70)
print("TRAIN-TEST SPLIT")
print("="*70)

# Split data (80-20)
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all,
    test_size=0.2,
    random_state=42
)

# Further split training into train/validation (80-20 of training)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train,
    test_size=0.2,
    random_state=42
)

print(f"\nDataset split:")
print(f"  Training:   {len(X_train):,} patches")
print(f"  Validation: {len(X_val):,} patches")
print(f"  Testing:    {len(X_test):,} patches")

# Calculate class balance
train_green_pct = 100 * y_train.mean()
print(f"\nGreen pixel percentage in training: {train_green_pct:.1f}%")


TRAIN-TEST SPLIT

Dataset split:
  Training:   1,540 patches
  Validation: 386 patches
  Testing:    482 patches

Green pixel percentage in training: 15.9%


## 7. Build U-Net Model

In [8]:
def conv_block(inputs, n_filters, kernel_size=3):
    """
    Convolutional block with two conv layers, batch norm, and ReLU.
    """
    x = layers.Conv2D(n_filters, kernel_size, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    x = layers.Conv2D(n_filters, kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    return x


def encoder_block(inputs, n_filters):
    """
    Encoder block: conv_block + max pooling.
    """
    x = conv_block(inputs, n_filters)
    p = layers.MaxPooling2D(pool_size=(2, 2))(x)
    return x, p


def decoder_block(inputs, skip_features, n_filters):
    """
    Decoder block: upsample + concatenate skip connection + conv_block.
    """
    x = layers.Conv2DTranspose(n_filters, (2, 2), strides=2, padding='same')(inputs)
    x = layers.Concatenate()([x, skip_features])
    x = conv_block(x, n_filters)
    return x


def build_unet(input_shape, n_classes=1, n_filters_start=32):
    """
    Build U-Net model.
    
    Args:
        input_shape: tuple (H, W, C)
        n_classes: number of output classes (1 for binary)
        n_filters_start: number of filters in first layer
    
    Returns:
        keras Model
    """
    inputs = layers.Input(shape=input_shape)
    
    # Encoder (downsampling path)
    s1, p1 = encoder_block(inputs, n_filters_start)       # 64 -> 32
    s2, p2 = encoder_block(p1, n_filters_start * 2)       # 32 -> 16
    s3, p3 = encoder_block(p2, n_filters_start * 4)       # 16 -> 8
    s4, p4 = encoder_block(p3, n_filters_start * 8)       # 8 -> 4
    
    # Bridge
    b = conv_block(p4, n_filters_start * 16)              # 4x4
    
    # Decoder (upsampling path)
    d1 = decoder_block(b, s4, n_filters_start * 8)        # 4 -> 8
    d2 = decoder_block(d1, s3, n_filters_start * 4)       # 8 -> 16
    d3 = decoder_block(d2, s2, n_filters_start * 2)       # 16 -> 32
    d4 = decoder_block(d3, s1, n_filters_start)           # 32 -> 64
    
    # Output
    if n_classes == 1:
        outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(d4)
    else:
        outputs = layers.Conv2D(n_classes, (1, 1), activation='softmax')(d4)
    
    model = Model(inputs, outputs, name='UNet')
    return model


# Build model
input_shape = (PATCH_SIZE, PATCH_SIZE, n_bands)
model = build_unet(input_shape, n_classes=1, n_filters_start=N_FILTERS_START)

print(f"\nU-Net Model Summary:")
print(f"  Input shape: {input_shape}")
print(f"  Starting filters: {N_FILTERS_START}")
print(f"  Total parameters: {model.count_params():,}")

model.summary()


U-Net Model Summary:
  Input shape: (64, 64, 21)
  Starting filters: 32
  Total parameters: 7,777,057


## 8. Compile and Train Model

In [None]:
print("\n" + "="*70)
print("COMPILING AND TRAINING MODEL")
print("="*70)

# Custom metrics
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def combined_loss(y_true, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

# Compile model
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=combined_loss,
    metrics=['accuracy', dice_coef]
)

# Callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    ModelCheckpoint(
        filepath=os.path.join(RUN_FOLDER, 'unet_best.keras'),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-6,
        verbose=1
    )
]

print(f"\nTraining configuration:")
print(f"  Optimizer: Adam (lr={LEARNING_RATE})")
print(f"  Loss: Binary Crossentropy + Dice Loss")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")

# Train model
print(f"\nStarting training...")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

print(f"\nTraining complete!")


COMPILING AND TRAINING MODEL

Training configuration:
  Optimizer: Adam (lr=0.001)
  Loss: Binary Crossentropy + Dice Loss
  Batch size: 32
  Epochs: 50

Starting training...
Epoch 1/50
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 593ms/step - accuracy: 0.8820 - dice_coef: 0.5263 - loss: 0.8172
Epoch 1: val_loss improved from None to 1.08878, saving model to /Users/tyomachka/Desktop/WU/Data_Lab.TMP/rep.infrared.city/unet_training/run_20260127_221426/unet_best.keras

Epoch 1: finished saving model to /Users/tyomachka/Desktop/WU/Data_Lab.TMP/rep.infrared.city/unet_training/run_20260127_221426/unet_best.keras
[1m49/49[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 643ms/step - accuracy: 0.9182 - dice_coef: 0.6007 - loss: 0.6580 - val_accuracy: 0.9161 - val_dice_coef: 0.3450 - val_loss: 1.0888 - learning_rate: 0.0010
Epoch 2/50
[1m23/49[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m15s[0m 614ms/step - accuracy: 0.9305 - dice_coef: 0.6611 - loss: 0.5427

## 9. Plot Training History

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(history.history['loss'], label='Train')
axes[0].plot(history.history['val_loss'], label='Validation')
axes[0].set_title('Loss', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Accuracy
axes[1].plot(history.history['accuracy'], label='Train')
axes[1].plot(history.history['val_accuracy'], label='Validation')
axes[1].set_title('Accuracy', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].legend()
axes[1].grid(alpha=0.3)

# Dice Coefficient
axes[2].plot(history.history['dice_coef'], label='Train')
axes[2].plot(history.history['val_dice_coef'], label='Validation')
axes[2].set_title('Dice Coefficient', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Dice')
axes[2].legend()
axes[2].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(RUN_FOLDER, 'training_history.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Training history plot saved")

## 10. Evaluate Model on Test Set

In [None]:
print("\n" + "="*70)
print("MODEL EVALUATION")
print("="*70)

# Load best model
best_model_path = os.path.join(RUN_FOLDER, 'unet_best.keras')
if os.path.exists(best_model_path):
    model = keras.models.load_model(
        best_model_path,
        custom_objects={'dice_coef': dice_coef, 'combined_loss': combined_loss}
    )
    print(f"Loaded best model from: {best_model_path}")

# Evaluate on test set
test_loss, test_acc, test_dice = model.evaluate(X_test, y_test, verbose=0)
print(f"\nTest Set Performance:")
print(f"  Loss: {test_loss:.4f}")
print(f"  Accuracy: {test_acc:.4f}")
print(f"  Dice Coefficient: {test_dice:.4f}")

# Get predictions
y_pred_proba = model.predict(X_test, verbose=0)
y_pred = (y_pred_proba > 0.5).astype(np.float32)

# Flatten for sklearn metrics
y_test_flat = y_test.flatten()
y_pred_flat = y_pred.flatten()

# Calculate metrics
accuracy = accuracy_score(y_test_flat, y_pred_flat)
precision = precision_score(y_test_flat, y_pred_flat, zero_division=0)
recall = recall_score(y_test_flat, y_pred_flat, zero_division=0)
f1 = f1_score(y_test_flat, y_pred_flat, zero_division=0)

print(f"\nPixel-wise Metrics:")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1-Score:  {f1:.4f}")

# Confusion Matrix
cm = confusion_matrix(y_test_flat, y_pred_flat)
print(f"\nConfusion Matrix:")
print(f"                 Predicted")
print(f"               Non-Green  Green")
print(f"Actual Non-Green  {cm[0,0]:>10,}  {cm[0,1]:>10,}")
print(f"       Green      {cm[1,0]:>10,}  {cm[1,1]:>10,}")

## 11. Visualize Sample Predictions

In [None]:
# Visualize sample predictions
n_samples = 6
indices = np.random.choice(len(X_test), n_samples, replace=False)

fig, axes = plt.subplots(n_samples, 4, figsize=(16, n_samples * 3))

for i, idx in enumerate(indices):
    # RGB composite (assuming bands 2,1,0 are B04, B03, B02 from April)
    rgb = X_test[idx][:, :, [2, 1, 0]]  # B04, B03, B02
    rgb = np.clip(rgb * 3, 0, 1)  # Enhance brightness
    
    axes[i, 0].imshow(rgb)
    axes[i, 0].set_title('RGB Composite' if i == 0 else '')
    axes[i, 0].axis('off')
    
    axes[i, 1].imshow(y_test[idx, :, :, 0], cmap='Greens', vmin=0, vmax=1)
    axes[i, 1].set_title('Ground Truth' if i == 0 else '')
    axes[i, 1].axis('off')
    
    axes[i, 2].imshow(y_pred[idx, :, :, 0], cmap='Greens', vmin=0, vmax=1)
    axes[i, 2].set_title('Prediction' if i == 0 else '')
    axes[i, 2].axis('off')
    
    # Difference (False Positives in red, False Negatives in blue)
    diff = np.zeros((*y_test[idx, :, :, 0].shape, 3))
    fp = (y_pred[idx, :, :, 0] > y_test[idx, :, :, 0]).astype(float)  # False positive
    fn = (y_pred[idx, :, :, 0] < y_test[idx, :, :, 0]).astype(float)  # False negative
    correct = (y_pred[idx, :, :, 0] == y_test[idx, :, :, 0]).astype(float) * y_test[idx, :, :, 0]
    diff[:, :, 0] = fp  # Red for FP
    diff[:, :, 1] = correct  # Green for correct
    diff[:, :, 2] = fn  # Blue for FN
    
    axes[i, 3].imshow(diff)
    axes[i, 3].set_title('FP(R) / Correct(G) / FN(B)' if i == 0 else '')
    axes[i, 3].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(RUN_FOLDER, 'sample_predictions.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Sample predictions saved")

## 12. Save Model and Metrics

In [None]:
# Save final model
final_model_path = os.path.join(RUN_FOLDER, 'unet_model.keras')
model.save(final_model_path)
print(f"Final model saved to: {final_model_path}")

# Save metrics
metrics = {
    "model": "UNet",
    "ground_truth": "WorldCover_2021",
    "training_cities": [city['name'] for city in complete_cities],
    "n_cities": len(complete_cities),
    "patch_size": PATCH_SIZE,
    "n_filters_start": N_FILTERS_START,
    "total_patches": len(X_all),
    "training_patches": len(X_train),
    "validation_patches": len(X_val),
    "testing_patches": len(X_test),
    "test_accuracy": float(accuracy),
    "test_precision": float(precision),
    "test_recall": float(recall),
    "test_f1_score": float(f1),
    "test_dice": float(test_dice),
    "confusion_matrix": cm.tolist(),
    "epochs_trained": len(history.history['loss']),
    "final_train_loss": float(history.history['loss'][-1]),
    "final_val_loss": float(history.history['val_loss'][-1])
}

with open(os.path.join(RUN_FOLDER, "metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)

print(f"Metrics saved to: {RUN_FOLDER}/metrics.json")

# Save normalization parameters for inference
norm_params = {
    "method": "per_channel_minmax",
    "range": [0, 1],
    "n_bands": n_bands
}
with open(os.path.join(RUN_FOLDER, "normalization_params.json"), "w") as f:
    json.dump(norm_params, f, indent=2)

print(f"Normalization params saved to: {RUN_FOLDER}/normalization_params.json")

## 13. Confusion Matrix Visualization

In [None]:
# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Non-Green', 'Green'],
            yticklabels=['Non-Green', 'Green'],
            cbar_kws={'label': 'Count'})
plt.title(f'Confusion Matrix - U-Net\n(Trained on {len(complete_cities)} cities)', 
          fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(RUN_FOLDER, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Confusion matrix saved")

## 14. Summary Report

In [None]:
print("\n" + "="*80)
print("U-NET TRAINING - SUMMARY REPORT")
print("="*80)

print(f"\nGround Truth: WorldCover 2021")
print(f"Green Classes: Tree cover (10), Shrubland (20), Grassland (30), Mangroves (95)")

print(f"\nU-Net Configuration:")
print(f"  Input shape: {input_shape}")
print(f"  Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"  Starting filters: {N_FILTERS_START}")
print(f"  Total parameters: {model.count_params():,}")

print(f"\nTraining Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs trained: {len(history.history['loss'])}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Loss: Binary Crossentropy + Dice Loss")

print(f"\nTraining Data:")
print(f"  Cities: {len(complete_cities)}")
for city in complete_cities:
    print(f"    - {city['name']}")

print(f"\n  Total patches: {len(X_all):,}")
print(f"  Training patches: {len(X_train):,}")
print(f"  Validation patches: {len(X_val):,}")
print(f"  Testing patches: {len(X_test):,}")

print(f"\nModel Performance (Test Set):")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1-Score:  {f1:.4f}")
print(f"  Dice:      {test_dice:.4f}")

print(f"\nOutput Files:")
print(f"  Results folder: {RUN_FOLDER}")
print(f"  - unet_model.keras (final model)")
print(f"  - unet_best.keras (best validation model)")
print(f"  - metrics.json")
print(f"  - normalization_params.json")
print(f"  - training_history.png")
print(f"  - sample_predictions.png")
print(f"  - confusion_matrix.png")

print(f"\n" + "="*80)
print(f"TRAINING COMPLETE!")
print(f"="*80)

## 15. Copy Model to Project Root (Optional)

In [None]:
import shutil

# Copy model to main models folder
src_model = os.path.join(RUN_FOLDER, 'unet_model.keras')
dst_model = os.path.join(MODELS_PATH, 'unet_model.keras')

shutil.copy(src_model, dst_model)

# Also copy normalization params
src_norm = os.path.join(RUN_FOLDER, 'normalization_params.json')
dst_norm = os.path.join(MODELS_PATH, 'unet_normalization_params.json')
shutil.copy(src_norm, dst_norm)

print(f"Model copied to: {dst_model}")
print(f"Normalization params copied to: {dst_norm}")

## 16. Inference Function for Full Images

Use this function to apply the trained U-Net to full-size images.

In [None]:
def predict_full_image(model, image, patch_size=64, overlap=16):
    """
    Apply U-Net to a full image using sliding window with overlap.
    
    Args:
        model: trained U-Net model
        image: numpy array of shape (H, W, C)
        patch_size: size of patches
        overlap: overlap between patches for smoother output
    
    Returns:
        prediction: numpy array of shape (H, W)
    """
    H, W, C = image.shape
    stride = patch_size - overlap
    
    # Pad image if necessary
    pad_h = (patch_size - H % stride) % stride
    pad_w = (patch_size - W % stride) % stride
    image_padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    
    H_pad, W_pad, _ = image_padded.shape
    
    # Initialize output arrays
    prediction_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
    count = np.zeros((H_pad, W_pad), dtype=np.float32)
    
    # Normalize image
    image_norm = normalize_image(image_padded)
    
    # Sliding window prediction
    for i in range(0, H_pad - patch_size + 1, stride):
        for j in range(0, W_pad - patch_size + 1, stride):
            patch = image_norm[i:i+patch_size, j:j+patch_size, :]
            patch_batch = np.expand_dims(patch, axis=0)
            
            pred = model.predict(patch_batch, verbose=0)[0, :, :, 0]
            
            prediction_sum[i:i+patch_size, j:j+patch_size] += pred
            count[i:i+patch_size, j:j+patch_size] += 1
    
    # Average overlapping predictions
    prediction = prediction_sum / np.maximum(count, 1)
    
    # Remove padding
    prediction = prediction[:H, :W]
    
    return prediction


print("Inference function defined")
print("\nUsage:")
print("  prediction = predict_full_image(model, image, patch_size=64, overlap=16)")
print("  binary_mask = (prediction > 0.5).astype(np.uint8)")