# 🌍 ISMN Validation Results: SMPS Water Balance Model

This notebook visualizes the validation results of the **SMPS (Soil Moisture Prediction System) Physics-Based Water Balance Model** against **ISMN (International Soil Moisture Network)** in-situ observations from the **TAHMO network** in Kenya and Ghana.

## Overview
- **Time Period**: January 2020 - December 2021 (2 years)
- **Data Sources**: ISMN TAHMO stations, Open-Meteo weather, iSDA soil data, MODIS NDVI
- **Regions**: East Africa (Kenya) and West Africa (Ghana)
- **Measurement Depths**: 10cm, 20cm, 30cm, 60cm

In [None]:
# Import Required Libraries
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# NOTE(review): this silences every warning for the rest of the session; consider
# narrowing it to the specific categories that are known to be noise.
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

# Status message: the leading symbol was mis-encoded text and is restored here.
print("✅ Libraries loaded successfully!")

## 1. Load Validation Results

Load the summary and daily results files generated by the validation script.

In [None]:
# Load the validation results.
# The directory can be overridden through the SMPS_ISMN_DIR environment variable so the
# notebook is not tied to a single machine; the default keeps the original location.
data_dir = Path(os.environ.get('SMPS_ISMN_DIR', '/home/viv/SMPS/data/ismn'))

summary_path = data_dir / 'ismn_validation_summary_multidepth.csv'
daily_path = data_dir / 'ismn_daily_results_multidepth.csv'

# Fail early with a message that says how to point the notebook at the data.
missing_files = [str(path) for path in (summary_path, daily_path) if not path.is_file()]
if missing_files:
    raise FileNotFoundError(
        f"Validation result file(s) not found: {missing_files}. "
        "Set SMPS_ISMN_DIR to the directory containing the validation CSVs."
    )

# Summary file with site-level metrics
summary_df = pd.read_csv(summary_path)

# Daily results with actual predictions; dates are parsed so they can be sorted and plotted
daily_df = pd.read_csv(daily_path)
daily_df['date'] = pd.to_datetime(daily_df['date'])

print("📊 Loaded validation results:")
print(f"   • {len(summary_df)} sites validated")
print(f"   • {len(daily_df):,} daily observations")
print(f"   • Date range: {daily_df['date'].min().date()} to {daily_df['date'].max().date()}")

# Display summary dataframe
summary_df[['station', 'country', 'depth_m', 'n', 'mae', 'rmse', 'r', 'bias']].round(4)

## 2. 📈 Overall Validation Metrics Summary

Key performance indicators across all validated sites:

In [None]:
# Distribution of each site-level metric (one panel per metric), followed by a printed summary.
metrics = ['mae', 'rmse', 'r', 'bias']
titles = ['Mean Absolute Error\n(MAE)', 'Root Mean Square Error\n(RMSE)', 'Correlation\n(r)', 'Bias']
colors = ['#e74c3c', '#e67e22', '#2ecc71', '#3498db']

# Derived from the data so the figure title cannot drift from the table that was loaded.
n_sites = len(summary_df)

# Local seeded generator: the horizontal jitter (and therefore the figure) is identical on every run.
jitter_rng = np.random.default_rng(42)

fig, axes = plt.subplots(1, 4, figsize=(16, 5))

for ax, metric, title, color in zip(axes, metrics, titles, colors):
    values = summary_df[metric].dropna()
    mean_val = values.mean()

    # Box plot summarising the distribution across sites
    bp = ax.boxplot(values, patch_artist=True, widths=0.5)
    bp['boxes'][0].set_facecolor(color)
    bp['boxes'][0].set_alpha(0.3)

    # Individual sites, jittered around x=1 (the box position) so points do not overlap
    ax.scatter(1 + jitter_rng.normal(0, 0.04, len(values)),
               values, alpha=0.6, color=color, s=60, edgecolor='white', linewidth=1)

    ax.axhline(mean_val, color=color, linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel('Value (m³/m³)' if metric != 'r' else 'Correlation coefficient')
    ax.legend()
    ax.set_xticks([])

plt.suptitle(f'📊 Overall Validation Performance Across {n_sites} Sites',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Print summary statistics
print("\n" + "="*70)
print("                    OVERALL VALIDATION METRICS")
print("="*70)
print(f"  📍 Total sites validated:      {n_sites}")
print(f"  📅 Total observation days:     {summary_df['n'].sum():,}")
# Indent once, then draw the rule; sized to fit inside the 70-character banner.
print("  " + "─" * 66)
print(f"  📉 Mean MAE:                   {summary_df['mae'].mean():.4f} m³/m³")
print(f"  📉 Mean RMSE:                  {summary_df['rmse'].mean():.4f} m³/m³")
print(f"  📈 Mean Correlation (r):       {summary_df['r'].mean():.3f}")
print(f"  📊 Mean Bias:                  {summary_df['bias'].mean():.4f} m³/m³")
print("="*70)

## 3. 🗺️ Site Locations Map

Visualization of the ISMN TAHMO stations used for validation across Kenya and Ghana.

In [None]:
# Map of site locations: one panel per country, marker colour = correlation (r).

# Separate by country
kenya_sites = summary_df[summary_df['country'] == 'Kenya']
ghana_sites = summary_df[summary_df['country'] == 'Ghana']

# Color by correlation performance
cmap = plt.cm.RdYlGn

# Station names longer than this are shortened in the annotations.
MAX_LABEL_CHARS = 15

# (sites, title prefix, longitude limits, latitude limits) for each panel.
# NOTE(review): sites outside these fixed limits are not drawn — confirm they cover all stations.
map_panels = [
    (kenya_sites, '🇰🇪 Kenya Sites (East Africa)', (33, 42), (-5, 1)),
    (ghana_sites, '🇬🇭 Ghana Sites (West Africa)', (-3, 1), (5, 12)),
]

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

for ax, (sites, title_prefix, lon_limits, lat_limits) in zip(axes, map_panels):
    # Shared vmin/vmax keeps the colour scale comparable between the two panels.
    sc = ax.scatter(sites['longitude'], sites['latitude'],
                    c=sites['r'], cmap=cmap, s=150, edgecolor='black',
                    linewidth=1.5, vmin=-0.2, vmax=1.0)
    for _, row in sites.iterrows():
        station = row['station']
        short_name = station[:MAX_LABEL_CHARS] + '...' if len(station) > MAX_LABEL_CHARS else station
        ax.annotate(short_name, (row['longitude'], row['latitude']), fontsize=8,
                    xytext=(5, 5), textcoords='offset points')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    # Count comes from the data rather than a hard-coded number.
    ax.set_title(f'{title_prefix}\n{len(sites)} stations', fontsize=12, fontweight='bold')
    ax.set_xlim(lon_limits)
    ax.set_ylim(lat_limits)
    fig.colorbar(sc, ax=ax, label='Correlation (r)')

plt.suptitle('📍 Validation Site Locations (Color = Model Performance)', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\n🌍 Geographic Coverage:")
print(f"   Kenya (East Africa): {len(kenya_sites)} sites, Lat range: {kenya_sites['latitude'].min():.2f}° to {kenya_sites['latitude'].max():.2f}°")
print(f"   Ghana (West Africa): {len(ghana_sites)} sites, Lat range: {ghana_sites['latitude'].min():.2f}° to {ghana_sites['latitude'].max():.2f}°")

## 4. 📊 Performance by Measurement Depth

Analysis of how model performance varies with sensor depth (10cm to 60cm).

In [None]:
# Performance breakdown by measurement depth

# Aggregate table kept for inspection; the plots below compute their own per-depth statistics.
depth_groups = summary_df.groupby('depth_m').agg({
    'mae': ['mean', 'std', 'count'],
    'rmse': ['mean', 'std'],
    'r': ['mean', 'std'],
    'bias': ['mean', 'std']
}).round(4)

depths = sorted(summary_df['depth_m'].unique())
# round() rather than int(): int(0.29 * 100) truncates to 28 because of floating-point error.
depth_labels = [f"{int(round(float(d) * 100))}cm" for d in depths]
colors_depth = plt.cm.viridis(np.linspace(0.2, 0.8, len(depths)))

# One subset per depth, reused by every panel and by the summary table.
sites_by_depth = {d: summary_df[summary_df['depth_m'] == d] for d in depths}

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (axes, column, y label, title, value-label offset, zero-based y axis with headroom)
bar_panels = [
    (axes[0, 0], 'mae', 'MAE (m³/m³)', 'Mean Absolute Error by Depth', 0.005, True),
    (axes[0, 1], 'rmse', 'RMSE (m³/m³)', 'Root Mean Square Error by Depth', 0.005, True),
    (axes[1, 0], 'r', 'Correlation (r)', 'Correlation Coefficient by Depth', 0.03, False),
]

for ax, column, ylabel, title, text_offset, zero_based in bar_panels:
    means = [sites_by_depth[d][column].mean() for d in depths]
    stds = [sites_by_depth[d][column].std() for d in depths]
    bars = ax.bar(depth_labels, means, yerr=stds, capsize=5, color=colors_depth, edgecolor='black')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Sensor Depth')
    ax.set_title(title, fontweight='bold')
    if zero_based:
        # Error metrics are non-negative: start at zero and leave room for the value labels.
        ax.set_ylim(0, max(means) * 1.5)
    else:
        # Correlation can be negative: mark zero instead of forcing the axis.
        ax.axhline(0, color='gray', linestyle='--', alpha=0.5)
    for bar, val in zip(bars, means):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + text_offset, f'{val:.3f}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')

# Number of sites per depth
ax_counts = axes[1, 1]
site_counts = [len(sites_by_depth[d]) for d in depths]
count_bars = ax_counts.bar(depth_labels, site_counts, color=colors_depth, edgecolor='black')
ax_counts.set_ylabel('Number of Sites')
ax_counts.set_xlabel('Sensor Depth')
ax_counts.set_title('Sites Validated per Depth', fontweight='bold')
for bar, val in zip(count_bars, site_counts):
    ax_counts.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.2, str(val),
                   ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.suptitle('📏 Model Performance by Measurement Depth', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Summary table; each value uses the same width as its header so the columns line up.
print("\n📊 Performance Summary by Depth:")
print("-" * 70)
print(f"{'Depth':<10} {'Sites':<8} {'MAE':<12} {'RMSE':<12} {'r':<10} {'Bias':<12}")
print("-" * 70)
for d, depth_label in zip(depths, depth_labels):
    df_d = sites_by_depth[d]
    print(f"{depth_label:<10} {len(df_d):<8} {df_d['mae'].mean():<12.4f} "
          f"{df_d['rmse'].mean():<12.4f} {df_d['r'].mean():<10.3f} {df_d['bias'].mean():.4f}")
print("-" * 70)

## 5. 🌍 Performance by Region (Kenya vs Ghana)

In [None]:
# Performance by region
# NOTE(review): assumes summary_df has a 'region' column holding exactly these two values — confirm.
regions = ['West Africa', 'East Africa']
region_colors = ['#e74c3c', '#3498db']

# Tick labels are keyed by region so they cannot fall out of step with `regions`.
region_tick_labels = {
    'West Africa': 'Ghana\n(West Africa)',
    'East Africa': 'Kenya\n(East Africa)',
}
region_frames = {region: summary_df[summary_df['region'] == region] for region in regions}

# The site count is shown under each tick; previously it was passed as a bar label that
# was never displayed because no legend was drawn.
tick_labels = [f"{region_tick_labels[region]}\nn={len(region_frames[region])}" for region in regions]

# (column, y label, title) for each panel.
region_panels = [
    ('mae', 'MAE (m³/m³)', 'Mean Absolute Error by Region'),
    ('rmse', 'RMSE (m³/m³)', 'RMSE by Region'),
    ('r', 'Correlation (r)', 'Correlation by Region'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, (column, ylabel, title) in zip(axes, region_panels):
    for i, region in enumerate(regions):
        region_data = region_frames[region]
        ax.bar(i, region_data[column].mean(), yerr=region_data[column].std(),
               capsize=5, color=region_colors[i], edgecolor='black')
    ax.set_xticks(range(len(regions)))
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    if column == 'r':
        # Correlation can be negative, so mark the zero line.
        ax.axhline(0, color='gray', linestyle='--', alpha=0.5)

plt.suptitle('🌍 Model Performance: West Africa (Ghana) vs East Africa (Kenya)', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Print regional summary
print("\n📊 Regional Performance Summary:")
print("-" * 60)
for region in regions:
    df_r = region_frames[region]
    print(f"\n  {region}:")
    print(f"    Sites: {len(df_r)}")
    print(f"    MAE:  {df_r['mae'].mean():.4f} ± {df_r['mae'].std():.4f} m³/m³")
    print(f"    RMSE: {df_r['rmse'].mean():.4f} ± {df_r['rmse'].std():.4f} m³/m³")
    print(f"    r:    {df_r['r'].mean():.3f} ± {df_r['r'].std():.3f}")
print("-" * 60)

## 6. 📈 Scatter Plots: Predicted vs Observed Soil Moisture

Comparison of model predictions against in-situ observations for selected sites.

In [None]:
# Scatter plots for top and bottom performing sites (ranked by correlation)
N_SITES_PER_ROW = 4
SM_AXIS_MAX = 0.6  # shared upper limit (m³/m³) for both axes and the 1:1 line

# Sort sites by correlation
best_sites = summary_df.nlargest(N_SITES_PER_ROW, 'r')
worst_sites = summary_df.nsmallest(N_SITES_PER_ROW, 'r')

fig, axes = plt.subplots(2, N_SITES_PER_ROW, figsize=(18, 10))

# (row of axes, sites to draw, marker colour, title badge)
scatter_rows = [
    (axes[0], best_sites, '#2ecc71', '✅'),
    (axes[1], worst_sites, '#e74c3c', '⚠️'),
]

for row_axes, sites, color, badge in scatter_rows:
    for ax, (_, site_info) in zip(row_axes, sites.iterrows()):
        site_data = daily_df[daily_df['site_id'] == site_info['site_id']]
        obs = site_data['soil_moisture'].values
        pred = site_data['sm_model'].values

        ax.scatter(obs, pred, alpha=0.4, s=15, c=color)
        ax.plot([0, SM_AXIS_MAX], [0, SM_AXIS_MAX], 'k--', linewidth=2, label='1:1 Line')
        ax.set_xlabel('Observed SM (m³/m³)')
        ax.set_ylabel('Predicted SM (m³/m³)')
        ax.set_title(f"{badge} {site_info['station'][:20]}\nr={site_info['r']:.3f}, RMSE={site_info['rmse']:.3f}", fontsize=10)
        ax.set_xlim(0, SM_AXIS_MAX)
        ax.set_ylim(0, SM_AXIS_MAX)
        ax.set_aspect('equal')
        # Draw the legend so the '1:1 Line' label is actually shown.
        ax.legend(loc='upper left', fontsize=8)

    # With fewer sites than panels, hide the leftover axes instead of leaving blank frames.
    for unused_ax in row_axes[len(sites):]:
        unused_ax.set_visible(False)

# The first panel of each row also carries the row caption.
axes[0, 0].set_ylabel('Predicted SM (m³/m³)\n\n[BEST SITES]', fontsize=11, fontweight='bold')
axes[1, 0].set_ylabel('Predicted SM (m³/m³)\n\n[CHALLENGING SITES]', fontsize=11, fontweight='bold')

plt.suptitle('📈 Predicted vs Observed Soil Moisture: Best & Most Challenging Sites', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 7. 📅 Time Series Comparison: Model vs Observations

Detailed time series showing how the model tracks soil moisture dynamics over time.

In [None]:
# Time series for the best-correlated site at each selected sensor depth
TARGET_DEPTHS_M = [0.1, 0.3, 0.6]

# Select sites with best correlation from each depth category
top_sites_by_depth = []
for depth in TARGET_DEPTHS_M:
    depth_sites = summary_df[summary_df['depth_m'] == depth]
    if len(depth_sites) > 0:
        best_site = depth_sites.nlargest(1, 'r').iloc[0]
        top_sites_by_depth.append(best_site)

# One panel per selected site, so no blank panels appear when a depth has no sites.
# At least one axes is created so the figure calls below stay valid.
n_panels = max(1, len(top_sites_by_depth))
fig, axes_grid = plt.subplots(n_panels, 1, figsize=(16, 4 * n_panels), squeeze=False)
axes = axes_grid[:, 0]

for ax, site_info in zip(axes, top_sites_by_depth):
    site_data = daily_df[daily_df['site_id'] == site_info['site_id']].copy()
    site_data = site_data.sort_values('date')

    # Plot observed and predicted
    ax.plot(site_data['date'], site_data['soil_moisture'],
            'b-', alpha=0.7, linewidth=1.5, label='Observed (ISMN)')
    ax.plot(site_data['date'], site_data['sm_model'],
            'r-', alpha=0.7, linewidth=1.5, label='Predicted (SMPS)')

    # Precipitation on a secondary axis, hanging from the top of the panel
    precip_ax = ax.twinx()
    precip_ax.bar(site_data['date'], site_data['precip_mm'],
                  alpha=0.3, color='skyblue', width=1, label='Precipitation')
    precip_ax.set_ylabel('Precipitation (mm)', color='skyblue')
    # Passing the limits as (100, 0) inverts the axis, so no separate invert call is needed.
    precip_ax.set_ylim(100, 0)

    # round() rather than int(): int() truncates values such as 0.29 * 100 to 28.
    depth_cm = int(round(float(site_info['depth_m']) * 100))
    ax.set_ylabel('Soil Moisture (m³/m³)')
    ax.set_title(f"📍 {site_info['station']} ({site_info['country']}) @ {depth_cm}cm | "
                 f"r={site_info['r']:.3f}, MAE={site_info['mae']:.4f}", fontsize=11, fontweight='bold')

    # Merge the handles of both axes so the precipitation entry appears in the legend too.
    line_handles, line_labels = ax.get_legend_handles_labels()
    precip_handles, precip_labels = precip_ax.get_legend_handles_labels()
    ax.legend(line_handles + precip_handles, line_labels + precip_labels, loc='upper left')
    ax.set_ylim(0, 0.6)
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Date')
plt.suptitle('📅 Time Series: Model Predictions vs In-Situ Observations (2020-2021)', fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

## 8. ⏱️ Prediction Window Analysis (24hr, 72hr, 7-day)

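Evaluate model accuracy for different prediction windows. As implemented below, a window of N days scores each site's stored daily predictions against observations after skipping the first N rows of that site's record; the model is not re-initialised from an observation at the start of each window, so differences between windows reflect only which early rows are excluded.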

In [None]:
def calculate_window_metrics(site_data, window_days):
    """Score stored predictions against observations after skipping leading rows.

    Rows are ordered by ``date``; every row from position ``window_days`` onward
    contributes the absolute difference between its ``sm_model`` and
    ``soil_moisture`` values. No forecast is re-initialised per window, so
    ``window_days`` only controls how many leading rows are excluded. It counts
    rows, not calendar days, so gaps in the record are not accounted for.

    Args:
        site_data: DataFrame with ``date``, ``soil_moisture`` and ``sm_model`` columns.
        window_days: Non-negative number of leading rows to skip.

    Returns:
        Dict with ``mae``, ``rmse`` and ``n`` (number of scored rows), or ``None``
        when no row can be scored.

    Raises:
        ValueError: If ``window_days`` is negative.
    """
    if window_days < 0:
        # A negative offset would silently wrap around to rows at the end of the record.
        raise ValueError(f"window_days must be non-negative, got {window_days}")

    ordered = site_data.sort_values('date')

    # Vectorised equivalent of comparing row i + window_days for every valid i.
    predicted = ordered['sm_model'].to_numpy(dtype=float)[window_days:]
    observed = ordered['soil_moisture'].to_numpy(dtype=float)[window_days:]
    abs_errors = np.abs(predicted - observed)

    # Skip pairs with a missing value so one gap does not turn the whole site's metrics into NaN.
    abs_errors = abs_errors[np.isfinite(abs_errors)]

    if abs_errors.size == 0:
        return None
    return {
        'mae': float(np.mean(abs_errors)),
        'rmse': float(np.sqrt(np.mean(abs_errors ** 2))),
        'n': int(abs_errors.size),
    }

# Analyze different prediction windows
# Caveat: calculate_window_metrics scores the stored predictions on the rows that remain
# after skipping the first `w` rows of a site, so windows differ only in the rows they exclude.
windows = [1, 3, 7, 14, 30]  # days: 24hr, 72hr, 1 week, 2 weeks, 1 month
window_labels = ['24hr', '72hr', '7 days', '14 days', '30 days']
MIN_SITE_ROWS = 30  # sites need more rows than this to be included

# Calculate metrics for each window across all sites
window_results = {w: {'mae': [], 'rmse': []} for w in windows}

for site_id, site_data in daily_df.groupby('site_id', sort=False):
    if len(site_data) <= MIN_SITE_ROWS:
        continue
    for w in windows:
        # Distinct name: `metrics` is the list of metric names defined in an earlier cell.
        site_window_metrics = calculate_window_metrics(site_data, w)
        if site_window_metrics:
            window_results[w]['mae'].append(site_window_metrics['mae'])
            window_results[w]['rmse'].append(site_window_metrics['rmse'])

# Across-site mean and spread per window, computed once and shared by the plots and the table.
window_means = {key: [np.mean(window_results[w][key]) for w in windows] for key in ('mae', 'rmse')}
window_stds = {key: [np.std(window_results[w][key]) for w in windows] for key in ('mae', 'rmse')}

# Plot results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
colors_window = plt.cm.plasma(np.linspace(0.2, 0.8, len(windows)))

# (column, y label, title) for each panel.
window_panels = [
    ('mae', 'MAE (m³/m³)', 'Mean Absolute Error by Prediction Window'),
    ('rmse', 'RMSE (m³/m³)', 'RMSE by Prediction Window'),
]

for ax, (key, ylabel, title) in zip(axes, window_panels):
    bars = ax.bar(window_labels, window_means[key], yerr=window_stds[key],
                  capsize=5, color=colors_window, edgecolor='black')
    ax.set_xlabel('Prediction Window')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    for bar, val in zip(bars, window_means[key]):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005, f'{val:.4f}',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.suptitle('⏱️ Model Accuracy Across Different Prediction Horizons', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Summary table; each cell is padded to its header's width so the columns line up.
print("\n📊 Prediction Window Performance Summary:")
print("-" * 60)
print(f"{'Window':<15} {'MAE (m³/m³)':<18} {'RMSE (m³/m³)':<18}")
print("-" * 60)
for idx, label in enumerate(window_labels):
    mae_cell = f"{window_means['mae'][idx]:.4f} ± {window_stds['mae'][idx]:.4f}"
    rmse_cell = f"{window_means['rmse'][idx]:.4f} ± {window_stds['rmse'][idx]:.4f}"
    print(f"{label:<15} {mae_cell:<18} {rmse_cell:<18}")
print("-" * 60)

## 9. 📊 Error Distribution Analysis

Distribution of prediction errors to understand model behavior.