In [None]:
#!/usr/bin/env python
# coding: utf-8

import os
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

# Silence FutureWarning noise. The filter installed here applies to this
# interpreter; the PYTHONWARNINGS variable is only read when a Python
# interpreter starts, so it matters for any child Python processes that
# inherit this environment, not for the current one.
warnings.simplefilter("ignore", FutureWarning)
os.environ['PYTHONWARNINGS'] = 'ignore::FutureWarning'

def _svm_load_dataset(file_path, feature_cols, target_col):
    """Load the training CSV, keep EC rows, add derived columns, drop incomplete rows."""
    df = pd.read_csv(file_path)
    # .copy() makes the filtered frame independent, so the column assignments
    # below are not chained assignments on a slice (SettingWithCopyWarning).
    df = df[df['flux_method'] == 'EC'].copy()

    # Derived features: mean of tmmn/tmmx, and a first-of-month timestamp used
    # for the time-series plots.
    df['tmean_C'] = df[['tmmn', 'tmmx']].mean(axis=1)
    df['date'] = pd.to_datetime(df[['year', 'month']].assign(day=1))

    # SVR cannot handle missing values, so keep only rows that are complete for
    # every predictor, the site id and the target.
    required_cols = feature_cols + ['site_reference', target_col]
    return df.dropna(subset=required_cols)


def _svm_encode_features(df, feature_cols, categorical_features):
    """Return a purely numeric design matrix with categorical columns one-hot encoded."""
    X_initial = df[feature_cols].copy()
    for col in categorical_features:
        X_initial[col] = X_initial[col].astype('category')
    # Encoding the full table once (before the CV split) guarantees every fold
    # sees identical dummy columns. dtype=float keeps the matrix numeric
    # (pandas >= 2 would otherwise produce bool dummy columns).
    return pd.get_dummies(X_initial, columns=categorical_features,
                          drop_first=True, dtype=float)


def _svm_loso_cv(df, X, y, svr_C):
    """Fit one StandardScaler -> RBF SVR per held-out site and predict that site.

    Returns a tuple (list of per-site metric dicts, list of per-site
    prediction DataFrames).
    """
    results = []
    all_preds_df_list = []
    site_ids = df["site_reference"]

    for test_site in site_ids.unique():
        print(f"  Processing site: {test_site}...")
        test_idx = site_ids == test_site
        train_idx = ~test_idx

        # Every test site has rows (sites come from the data itself), but if the
        # dataset holds a single site there is nothing to train on and
        # SVR.fit would raise on an empty training set.
        if not train_idx.any():
            print(f"  Skipping site {test_site}: no training data from other sites.")
            continue

        X_train, y_train = X.loc[train_idx], y.loc[train_idx]
        X_test, y_test = X.loc[test_idx], y.loc[test_idx]
        dates_test = df.loc[test_idx, "date"]

        # Scaling inside the pipeline means the scaler is fitted on the training
        # sites only, so no statistics leak from the held-out site.
        model = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=svr_C))
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        all_preds_df_list.append(pd.DataFrame({
            "Site": test_site, "Date": dates_test.values,
            "Observed": y_test.values, "Predicted": y_pred,
        }))
        # Note: r2_score is undefined (NaN plus a warning) for a site with
        # fewer than two samples.
        results.append({
            "Site": test_site,
            "RMSE": np.sqrt(mean_squared_error(y_test, y_pred)),
            "MAE": mean_absolute_error(y_test, y_pred),
            "R2": r2_score(y_test, y_pred),
        })

    return results, all_preds_df_list


def _svm_save_site_plot(site_df, site_metrics, site, target_col, figures_path):
    """Save one observed-vs-predicted time-series PNG for a single site."""
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        ax.plot(site_df["Date"], site_df["Observed"], label="Observed", marker="o", linestyle='-', markersize=4)
        ax.plot(site_df["Date"], site_df["Predicted"], label="Predicted", marker="x", linestyle='--', markersize=4)
        ax.set_title(f"Observed vs. Predicted {target_col} (SVM) for Site: {site}")
        ax.legend()
        ax.grid(True)
        fig.autofmt_xdate()

        textstr = f"RMSE: {site_metrics['RMSE']:.2f}\nMAE: {site_metrics['MAE']:.2f}\nR²: {site_metrics['R2']:.2f}"
        ax.text(0.97, 0.03, textstr, transform=ax.transAxes, fontsize=10,
                verticalalignment='bottom', horizontalalignment='right',
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))

        # A path separator inside a site name would otherwise be interpreted as
        # a (nonexistent) sub-directory; names without one are unchanged.
        safe_site = str(site)
        for sep in filter(None, (os.sep, os.altsep)):
            safe_site = safe_site.replace(sep, '_')
        plot_filename = f'svm_{target_col}_{safe_site}_timeseries.png'
        fig.savefig(os.path.join(figures_path, plot_filename), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def run_loso_svm_analysis(
    target_col,
    file_path="/explore/nobackup/people/spotter5/anna_v/v2/v2_model_training_final.csv",
    out_root="/explore/nobackup/people/spotter5/anna_v/v2/loocv",
    svr_C=100,
):
    """
    Run Leave-One-Site-Out (LOSO) cross-validation of an RBF-kernel SVR for one target.

    The process:
    1. Load the CSV, keep rows with flux_method == 'EC', derive 'tmean_C' and 'date'.
    2. Drop rows missing any predictor, the site id, or the target.
    3. One-hot encode the categorical predictors ('land_cover', 'month').
    4. For each site, fit StandardScaler -> SVR on all other sites and predict it.
    5. Save per-site metrics and all predictions to CSV; print pooled metrics and
       the median per-site R².
    6. Save one observed-vs-predicted time-series plot per site.

    Parameters
    ----------
    target_col : str
        Column of the input CSV to predict (e.g. 'nee').
    file_path : str, optional
        Input CSV path (defaults to the original hard-coded location).
    out_root : str, optional
        Output root. CSVs go to ``<out_root>/<target_col>``, plots to
        ``<out_root>/<target_col>/figures/svm``.
    svr_C : float, optional
        SVR regularisation parameter C (default 100).

    Returns
    -------
    None. All results are written to disk and printed.
    """
    print(f"Processing target variable: {target_col}")

    # --- Predictors ---
    feature_cols = [
        'EVI', 'NDVI', 'sur_refl_b01', 'sur_refl_b02', 'sur_refl_b03',
        'sur_refl_b07', 'NDWI', 'pdsi', 'srad', 'tmean_C', 'vap', 'vs',
        'bdod_0_100cm', 'cec_0_100cm', 'cfvo_0_100cm', 'clay_0_100cm',
        'nitrogen_0_100cm', 'ocd_0_100cm', 'phh2o_0_100cm', 'sand_0_100cm',
        'silt_0_100cm', 'soc_0_100cm', 'co2_cont', 'ALT',
        'land_cover', 'month',
        'lai', 'fpar', 'Percent_NonTree_Vegetation',
        'Percent_NonVegetated', 'Percent_Tree_Cover'
    ]
    categorical_features = ['land_cover', 'month']

    # --- Load and clean ---
    df = _svm_load_dataset(file_path, feature_cols, target_col)

    # --- Output locations (SVM plots live in their own nested directory) ---
    out_path = os.path.join(out_root, target_col)
    figures_path = os.path.join(out_path, "figures", "svm")
    os.makedirs(figures_path, exist_ok=True)

    # --- Features (X) and target (y) ---
    print("  One-hot encoding categorical features...")
    X = _svm_encode_features(df, feature_cols, categorical_features)
    y = df[target_col]

    # --- Leave-One-Site-Out CV ---
    results, all_preds_df_list = _svm_loso_cv(df, X, y, svr_C)

    if not results:
        print(f"No data processed for target '{target_col}'.")
        return

    results_df = pd.DataFrame(results)
    all_preds_df = pd.concat(all_preds_df_list, ignore_index=True)

    # --- Save results (files carry an 'svm_' prefix) ---
    results_csv_path = os.path.join(out_path, f'svm_results_{target_col}.csv')
    predictions_csv_path = os.path.join(out_path, f'svm_predictions_{target_col}.csv')
    results_df.to_csv(results_csv_path, index=False)
    all_preds_df.to_csv(predictions_csv_path, index=False)
    print(f"\n  Results saved to: {results_csv_path}")
    print(f"  Predictions saved to: {predictions_csv_path}")

    # --- Pooled metrics over all held-out predictions, plus median per-site R² ---
    observed, predicted = all_preds_df["Observed"], all_preds_df["Predicted"]
    rmse_all = np.sqrt(mean_squared_error(observed, predicted))
    r2_all = r2_score(observed, predicted)
    mae_all = mean_absolute_error(observed, predicted)

    print("\n--- Pooled Metrics ---")
    print(f"Pooled RMSE: {rmse_all:.4f}")
    print(f"Pooled MAE:  {mae_all:.4f}")
    print(f"Pooled R²:   {r2_all:.4f}")

    median_r2 = results_df["R2"].median()
    print("\n--- Median Metrics Across Sites ---")
    print(f"Median R²:   {median_r2:.4f}")

    # --- Per-site plots ---
    print("\n  Generating and saving individual site plots...")
    for site in all_preds_df["Site"].unique():
        site_df = all_preds_df[all_preds_df["Site"] == site].sort_values("Date")
        site_metrics = results_df[results_df["Site"] == site].iloc[0]
        _svm_save_site_plot(site_df, site_metrics, site, target_col, figures_path)

    print(f"  All site plots saved to: {figures_path}")


if __name__ == '__main__':
    # Targets supported by the training table include 'gpp', 'nee', 'reco'
    # and 'ch4_flux_total'; only 'nee' is run here.
    targets_to_run = ['nee']
    rule = '=' * 30

    for target_name in targets_to_run:
        label = target_name.upper()
        print(f"\n{rule}\n RUNNING SVM ANALYSIS FOR: {label}\n{rule}")
        run_loso_svm_analysis(target_col=target_name)
        print(f"\n{rule}\n COMPLETED SVM ANALYSIS FOR: {label}\n{rule}")


 RUNNING SVM ANALYSIS FOR: NEE
Processing target variable: nee
  One-hot encoding categorical features...
  Processing site: Hakasia 5yr_RU-Ha2_tower...
  Processing site: Hakasia Steppe_RU-Ha1_tower...
  Processing site: Kaamanen_FI-Kaa_tower...
  Processing site: Manitoba - Northern Old Black Spruce (former BOREAS Northern Study Area)_CA-Man_tower...
  Processing site: Nelegel_RU-Nel_tower...
  Processing site: Neleger Cutover_RU-NeC_tower...
  Processing site: Neleger larch forest_RU-NeF_tower...
  Processing site: Saskatchewan - Western Boreal, forest burned in 1989_CA-SF2_tower...
  Processing site: Sodankyla_FI-Sod_tower...
  Processing site: UCI-1850 burn site_CA-NS1_tower...
  Processing site: UCI-1964 burn site_CA-NS3_tower...
  Processing site: UCI-1981 burn site_CA-NS5_tower...
  Processing site: UCI-1989 burn site_CA-NS6_tower...
  Processing site: UCI-1998 burn site_CA-NS7_tower...
  Processing site: UCI-1964 burn site wet_CA-NS4_tower...
  Processing site: Delta Junction



  Processing site: Pond Inlet_CA-Pin_tower...




  Processing site: Poker Flat Research Range: Succession from fire scar to deciduous forest_US-Rpf_tower...
  Processing site: Samoylov Island_RU-Sam (open)_tower...
  Processing site: Udleg practice forest_MN-Udg_tower...
  Processing site: Daring Lake_CA-DL3_tower...
  Processing site: Elgeeii forest station_RU-Ege_tower...
  Processing site: Bonanza Creek Black Spruce_US-BZS_tower...
  Processing site: Daring Lake_CA-DL4_tower...
  Processing site: Tiksi_RU-Tks_tower...
  Processing site: Samoylov Island_RU-Sam (closed)_tower...
  Processing site: Bonanza Creek Thermokarst Bog_US-BZB_tower...
  Processing site: Poker Flat Research Range Black Spruce Forest_US-Prr_tower...
  Processing site: Bonanza Creek Rich Fen_US-BZF_tower...
  Processing site: Cascaden Ridge Fire Scar_US-Fcr_tower...




  Processing site: Lake Hazen, Ellesmere Island_CA-LHazen2-meadow wetland_tower...
  Processing site: Cherskii ecotone_RU-Eusk_cher1_tower...
  Processing site: Sammaltunturi fell_FI-SamFell_tower...
  Processing site: ARM-NSA-Barrow_US-A10_tower...
  Processing site: Adventdalen_SJ-Adv_tower...
  Processing site: Stordalen Fen_SE-St1_tower...
  Processing site: NGEE Arctic Barrow_US-NGB_tower...
  Processing site: Cherskii disturbed forest_RU-Eusk_cher2_tower...
  Processing site: Disko_GL-Dsk_tower...
  Processing site: Havikpak Creek_CA-HPC_tower...
  Processing site: Scotty Creek Landscape_CA-SCC_tower...
  Processing site: ZOTTO Bog_RU-Zo1_tower...
  Processing site: ZOTTO Forest_RU-Zo2_tower...
  Processing site: Trail Valley Creek_CA-TVC_tower...
  Processing site: Cherskii reference_RU-Ch2_tower...
  Processing site: Flux Observations of Carbon from an Airborne Laboratory (FOCAL) Campaign Site 1_US-Fo1_tower...
  Processing site: Barrow-CMDL_US-Brw_tower...
  Processing site: S

KeyboardInterrupt: 