In [2]:
import os
import pandas as pd
from PIL import Image
import warnings

# Silence Pillow's corrupt-EXIF warning when it shows up while opening plots.
# `message` is a regular expression matched against the start of the warning text.
warnings.filterwarnings(action="ignore", message="Possibly corrupt EXIF data. ")

def _extract_sites(filenames, target_col):
    """Return the sorted, unique site names encoded in CatBoost plot filenames.

    Only names of the exact form ``catboost_<target_col>_<site>_timeseries_cat.png``
    contribute a site. Any other entry (other plot types, sub-directories,
    other targets, an empty site part) is ignored instead of producing a
    bogus site name.
    """
    prefix = f'catboost_{target_col}_'
    suffix = '_timeseries_cat.png'
    # The site part must be non-empty, so the name must be strictly longer than prefix + suffix.
    min_len = len(prefix) + len(suffix)
    return sorted({
        name[len(prefix):-len(suffix)]
        for name in filenames
        if len(name) > min_len and name.startswith(prefix) and name.endswith(suffix)
    })


def _combine_side_by_side(left_path, right_path, output_path):
    """Paste two images next to each other on a white RGB canvas and save it.

    The canvas is as wide as both images together and as tall as the taller
    one. Both source files are opened in a ``with`` statement, so they are
    closed even when the second file is missing or unreadable.
    """
    with Image.open(left_path) as left, Image.open(right_path) as right:
        left_width, left_height = left.size
        right_width, right_height = right.size
        canvas = Image.new('RGB', (left_width + right_width, max(left_height, right_height)), 'white')
        # Paste while the sources are still open; paste() loads their pixel data.
        canvas.paste(left, (0, 0))
        canvas.paste(right, (left_width, 0))
    canvas.save(output_path)


def combine_and_save_plots(target_col, base_path="/explore/nobackup/people/spotter5/anna_v/v2/loocv"):
    """
    Finds existing CatBoost and SVM plots for each site,
    combines them side-by-side, and saves a new image.

    Reads ``<base_path>/<target_col>/figures/catboost_<target>_<site>_timeseries_cat.png``
    and ``<base_path>/<target_col>/figures_svm_top_features/svm_<target>_<site>_timeseries_top_features.png``.
    Writes ``comparison_<target>_<site>.png`` (CatBoost left, SVM right) to
    ``<base_path>/<target_col>/figures/svm_v_cat_combined``.

    Args:
        target_col: target variable name; used both as a sub-directory and in filenames.
        base_path: root of the LOOCV output tree.

    Returns:
        None. Progress and per-site failures are printed. A site whose images
        are missing or unreadable is skipped and does not abort the batch.
    """
    # --- 1. Define Paths ---
    catboost_figures_path = os.path.join(base_path, target_col, "figures")
    svm_figures_path = os.path.join(base_path, target_col, "figures_svm_top_features")
    comparison_output_path = os.path.join(catboost_figures_path, "svm_v_cat_combined")

    print(f"\n--- Processing Target: {target_col.upper()} ---")
    print(f"Reading images from: {catboost_figures_path} and {svm_figures_path}")

    # --- 2. Find all unique sites from the filenames ---
    # Check the input directory BEFORE creating the output directory (which lives
    # inside it). Otherwise the input directory would always exist and this check
    # could never fail.
    try:
        cb_files = os.listdir(catboost_figures_path)
    except FileNotFoundError:
        print(f"❌ Error: Directory not found: {catboost_figures_path}. Skipping.")
        return

    sites = _extract_sites(cb_files, target_col)
    if not sites:
        print("No site plots found to combine.")
        return

    os.makedirs(comparison_output_path, exist_ok=True)
    print(f"Found {len(sites)} unique sites. Combining images...")

    # --- 3. Loop, Combine, and Save ---
    for site in sites:
        catboost_img_path = os.path.join(catboost_figures_path, f'catboost_{target_col}_{site}_timeseries_cat.png')
        svm_img_path = os.path.join(svm_figures_path, f'svm_{target_col}_{site}_timeseries_top_features.png')
        output_path = os.path.join(comparison_output_path, f'comparison_{target_col}_{site}.png')

        try:
            _combine_side_by_side(catboost_img_path, svm_img_path, output_path)
        except FileNotFoundError as e:
            print(f"  - ❌ Skipping site '{site}': Missing one or both image files. Details: {e.filename}")
        except Exception as e:  # best-effort batch: one bad image must not stop the rest
            print(f"  - ❌ An error occurred for site '{site}': {e}")
        else:
            print(f"  - Saved combined plot for site: {site}")

    print(f"\n✅ Combination complete for '{target_col}'. Plots saved to: {comparison_output_path}")


if __name__ == '__main__':
    # Build the CatBoost-vs-SVM comparison images for each target in turn.
    targets_to_run = [
        'nee',
        'gpp',
        'reco',
        'ch4_flux_total',
    ]
    for target in targets_to_run:
        combine_and_save_plots(target_col=target)


--- Processing Target: NEE ---
Reading images from: /explore/nobackup/people/spotter5/anna_v/v2/loocv/nee/figures and /explore/nobackup/people/spotter5/anna_v/v2/loocv/nee/figures_svm_top_features
Found 555 unique sites. Combining images...
  - ❌ Skipping site 'Fyodorovskoye_RU-Fyo_tower': Missing one or both image files. Details: /panfs/ccds02/nobackup/people/spotter5/anna_v/v2/loocv/nee/figures_svm_top_features/svm_nee_Fyodorovskoye_RU-Fyo_tower_timeseries_top_features.png
  - Saved combined plot for site: Saskatchewan - Western Boreal, Mature Aspen_CA-Oas_tower
  - ❌ Skipping site 'Hyytiala_FI-Hyy_tower': Missing one or both image files. Details: /panfs/ccds02/nobackup/people/spotter5/anna_v/v2/loocv/nee/figures_svm_top_features/svm_nee_Hyytiala_FI-Hyy_tower_timeseries_top_features.png
  - Saved combined plot for site: Manitoba - Northern Old Black Spruce (former BOREAS Northern Study Area)_CA-Man_tower
  - Saved combined plot for site: Saskatchewan - Western Boreal, Mature Black S

In [None]:
import os
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

# Quiet FutureWarning noise. PYTHONWARNINGS is only read when an interpreter
# starts, so the environment entry affects child processes launched later.
# The explicit filter covers the current interpreter.
os.environ.update(PYTHONWARNINGS='ignore::FutureWarning')
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Hyper-parameters shared by every LOSO fold model and the final all-data model.
# Keeping them in one place stops the two configurations from drifting apart.
_CATBOOST_PARAMS = {
    'iterations': 1200,
    'learning_rate': 0.01,
    'depth': 8,
    'subsample': 0.7,
    'random_state': 42,
    'l2_leaf_reg': 0.1,
    # 'rsm' is CatBoost's native per-split feature-subsampling fraction
    # (its alias is colsample_bylevel, not colsample_bytree).
    'rsm': 0.8,
    'verbose': 0,
    'allow_writing_files': False,
}


def _prepare_frame(raw_df, target_col):
    """Return a cleaned copy of the training table for *target_col*.

    Casts ``land_cover`` and ``month`` to int, keeps only rows with
    ``flux_method == 'EC'``, adds ``tmean_C`` (row mean of ``tmmn`` and
    ``tmmx``) and ``date`` (first day of ``year``/``month``), and drops rows
    missing ``site_reference`` or the target. *raw_df* itself is not modified.
    """
    df = raw_df.astype({'land_cover': int, 'month': int})
    # .copy() makes the filtered frame independent, so the column assignments
    # below do not write into a slice (avoids SettingWithCopyWarning).
    df = df[df['flux_method'] == 'EC'].copy()
    df['tmean_C'] = df[['tmmn', 'tmmx']].mean(axis=1)
    df['date'] = pd.to_datetime(df[['year', 'month']].assign(day=1))
    return df.dropna(subset=['site_reference', target_col])


def _new_catboost(categorical_features):
    """Build an unfitted CatBoostRegressor with the shared hyper-parameters."""
    return CatBoostRegressor(cat_features=categorical_features, **_CATBOOST_PARAMS)


def _save_site_plot(site_df, site_metrics, target_col, site, plot_path):
    """Draw the observed vs. predicted time series for one site and save it to *plot_path*."""
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        ax.plot(site_df["Date"], site_df["Observed"], label="Observed", marker="o", linestyle='-', markersize=4)
        ax.plot(site_df["Date"], site_df["Predicted"], label="Predicted", marker="x", linestyle='--', markersize=4)
        ax.set_title(f"Observed vs. Predicted {target_col} for Site: {site}")
        ax.legend()
        ax.grid(True)
        fig.autofmt_xdate()

        textstr = f"RMSE: {site_metrics['RMSE']:.2f}\nMAE: {site_metrics['MAE']:.2f}\nR²: {site_metrics['R2']:.2f}"
        ax.text(0.97, 0.03, textstr, transform=ax.transAxes, fontsize=10,
                verticalalignment='bottom', horizontalalignment='right',
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))

        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure. One figure is created per site, and they
        # would otherwise accumulate if a plot call raises.
        plt.close(fig)


def run_loso_analysis(
    target_col,
    data_path="/explore/nobackup/people/spotter5/anna_v/v2/v2_model_training_final.csv",
    loocv_root="/explore/nobackup/people/spotter5/anna_v/v2/loocv",
    models_out_path='/explore/nobackup/people/spotter5/anna_v/v2/models',
):
    """
    Performs a full Leave-One-Site-Out (LOSO) cross-validation for a given target variable.

    For each site in turn, a CatBoost model is trained on all other sites and
    evaluated on the held-out site. Writes:

    * ``catboost_results_<target>_cat.csv``: per-site RMSE/MAE/R².
    * ``catboost_predictions_<target>_cat.csv``: all held-out predictions.

      Both CSVs go under ``<loocv_root>/<target_col>``.
    * One observed-vs-predicted PNG per site, in the ``figures`` sub-directory.
    * A final model fitted on all rows, saved to ``<models_out_path>/<target_col>.json``.

    Args:
        target_col: name of the column to predict.
        data_path: CSV containing the feature columns, ``site_reference``,
            ``flux_method``, ``year``/``month``, ``tmmn``/``tmmx`` and the target.
        loocv_root: root directory for the LOSO outputs.
        models_out_path: directory for the final model file.

    Returns:
        None. Returns early, after printing a message, when no site produced results.
    """
    print(f"--- Processing Target: {target_col.upper()} ---")
    df = _prepare_frame(pd.read_csv(data_path), target_col)

    feature_cols = [
        'EVI', 'NDVI', 'sur_refl_b01', 'sur_refl_b02', 'sur_refl_b03',
        'sur_refl_b07', 'NDWI', 'pdsi', 'srad', 'tmean_C', 'vap', 'vs',
        'bdod_0_100cm', 'cec_0_100cm', 'cfvo_0_100cm', 'clay_0_100cm',
        'nitrogen_0_100cm', 'ocd_0_100cm', 'phh2o_0_100cm', 'sand_0_100cm',
        'silt_0_100cm', 'soc_0_100cm', 'co2_cont', 'ALT',
        'land_cover', 'month',
        'lai', 'fpar', 'Percent_NonTree_Vegetation',
        'Percent_NonVegetated', 'Percent_Tree_Cover'
    ]
    categorical_features = ['land_cover', 'month']

    loocv_out_path = os.path.join(loocv_root, target_col)
    figures_path = os.path.join(loocv_out_path, "figures")
    for out_dir in (loocv_out_path, figures_path, models_out_path):
        os.makedirs(out_dir, exist_ok=True)

    X = df[feature_cols].copy()
    y = df[target_col]
    sites = df["site_reference"].unique()

    for col in categorical_features:
        X[col] = X[col].astype('category')

    results = []
    all_preds_df_list = []

    for test_site in sites:
        print(f"  Processing site: {test_site}...")
        train_idx = df["site_reference"] != test_site
        test_idx = df["site_reference"] == test_site

        if test_idx.sum() < 1:
            continue

        X_train, y_train = X.loc[train_idx], y.loc[train_idx]
        X_test, y_test = X.loc[test_idx], y.loc[test_idx]
        dates_test = df.loc[test_idx, "date"]

        model = _new_catboost(categorical_features)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        all_preds_df_list.append(pd.DataFrame({
            "Site": test_site,
            "Date": dates_test.values,
            "Observed": y_test.values,
            "Predicted": y_pred,
        }))
        results.append({
            "Site": test_site,
            "RMSE": np.sqrt(mean_squared_error(y_test, y_pred)),
            "MAE": mean_absolute_error(y_test, y_pred),
            "R2": r2_score(y_test, y_pred),
        })

    if not results:
        print(f"No data processed for target '{target_col}'. Skipping.")
        return

    results_df = pd.DataFrame(results)
    all_preds_df = pd.concat(all_preds_df_list, ignore_index=True)

    results_csv_path = os.path.join(loocv_out_path, f'catboost_results_{target_col}_cat.csv')
    predictions_csv_path = os.path.join(loocv_out_path, f'catboost_predictions_{target_col}_cat.csv')
    results_df.to_csv(results_csv_path, index=False)
    all_preds_df.to_csv(predictions_csv_path, index=False)
    print(f"\n  Results saved to: {results_csv_path}")

    # Pooled metrics over every held-out prediction from all sites.
    rmse_all = np.sqrt(mean_squared_error(all_preds_df["Observed"], all_preds_df["Predicted"]))
    r2_all = r2_score(all_preds_df["Observed"], all_preds_df["Predicted"])
    mae_all = mean_absolute_error(all_preds_df["Observed"], all_preds_df["Predicted"])
    print(f"\n  --- Pooled Metrics for {target_col.upper()} ---")
    print(f"  Pooled R²: {r2_all:.4f}, Pooled RMSE: {rmse_all:.4f}, Pooled MAE: {mae_all:.4f}")

    print("\n  Generating and saving individual site plots...")
    for site in all_preds_df["Site"].unique():
        site_df = all_preds_df[all_preds_df["Site"] == site].sort_values("Date")
        site_metrics = results_df[results_df["Site"] == site].iloc[0]
        plot_path = os.path.join(figures_path, f'catboost_{target_col}_{site}_timeseries_cat.png')
        _save_site_plot(site_df, site_metrics, target_col, site, plot_path)
    print(f"  All site plots saved to: {figures_path}")

    print("\n  Training and saving final model on all data...")
    final_model = _new_catboost(categorical_features)
    final_model.fit(X, y)
    model_filename = os.path.join(models_out_path, f'{target_col}.json')
    # NOTE(review): save_model() is called without `format=`, so CatBoost writes
    # its default binary "cbm" format even though the file is named .json.
    # Loaders must use the default (cbm) format. Confirm with consumers before
    # changing either the extension or the format.
    final_model.save_model(model_filename)
    print(f"  Final model saved to: {model_filename}")

if __name__ == '__main__':
    # Run the full LOSO pipeline for each target, framed by banner lines.
    banner = '=' * 50
    targets_to_run = ['gpp', 'nee', 'reco', 'ch4_flux_total']
    for target in targets_to_run:
        label = target.upper()
        print(f"\n{banner}\nRUNNING ANALYSIS FOR: {label}\n{banner}")
        run_loso_analysis(target_col=target)
        print(f"\n{banner}\nCOMPLETED ANALYSIS FOR: {label}\n{banner}")