In [7]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import warnings

# Silence routine warnings so the per-site progress log stays readable
warnings.filterwarnings('ignore')

# ---------------- Font sizes (points) ----------------
FS_TITLE    = 26  # figure suptitle
FS_SUBTITLE = 22  # title above each panel
FS_AXLABEL  = 22  # x / y axis labels
FS_TICKS    = 18  # x / y tick labels
FS_LEGEND   = 20  # legend entries

# 1. Load the dataset, keeping eddy-covariance ('EC') records only
print("--- 📂 Loading dataset ---")
try:
    df_new = pd.read_csv("/explore/nobackup/people/spotter5/anna_v/v2/v2_model_training_final.csv")
    # .copy() makes df_new an independent frame rather than a filtered slice of
    # the raw CSV, so the column assignments below are unambiguous (this is the
    # SettingWithCopyWarning case that the blanket warnings filter would hide).
    df_new = df_new[df_new['flux_method'] == 'EC'].copy()

    # --- FEATURE ENGINEERING ---
    # Mean temperature: row-wise mean of the tmmn and tmmx columns.
    df_new['tmean_C'] = df_new[['tmmn', 'tmmx']].mean(axis=1)
    # First-of-month timestamp, used as the x-axis of every panel.
    df_new['date'] = pd.to_datetime(df_new[['year', 'month']].assign(day=1))

except FileNotFoundError as e:
    print(f"Error: The data file was not found.\n{e}")
    raise
except KeyError as e:
    print(f"Error: A required column is missing: {e}")
    raise

# 2. Define output path and get site list
comparison_plot_path = os.path.join(
    "/explore/nobackup/people/spotter5/anna_v/v2/exploration",
    "variable_comparisons_methane"
)
os.makedirs(comparison_plot_path, exist_ok=True)
print(f"Plots will be saved to: {comparison_plot_path}")

all_sites = df_new['site_reference'].dropna().unique()
print(f"Found {len(all_sites)} unique sites to process.")

# --- helper plotting function ---
def create_comparison_plot(ax, site_data, var_name, var_unit, var_color):
    """Plot CH4 flux and one driver variable against date in a single panel.

    ``ch4_flux_total`` is drawn on the left y-axis of ``ax``; ``var_name`` is
    drawn on a right-hand twin axis created here.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Panel to draw into.
    site_data : pandas.DataFrame
        Rows for one site; must contain 'date', 'ch4_flux_total' and ``var_name``.
    var_name : str
        Column plotted on the secondary axis (also used in legend and title).
    var_unit : str
        Unit text shown in the secondary y-axis label.
    var_color : str
        Matplotlib colour for the variable line and its axis labels.

    Returns
    -------
    matplotlib.axes.Axes
        The twin axis carrying ``var_name`` (callers may ignore it).
    """
    ax_twin = ax.twinx()

    # CH4 on the primary (left) axis
    line_ch4 = ax.plot(
        site_data['date'], site_data['ch4_flux_total'],
        color='lightgray', label='CH4', linewidth=2.5, zorder=1
    )
    ax.set_ylabel('CH4', color='gray', fontsize=FS_AXLABEL)
    ax.tick_params(axis='y', labelcolor='gray', labelsize=FS_TICKS)
    ax.grid(True, linestyle='--', alpha=0.5)

    # Variable on the secondary (right) axis. The twin axis is rendered on top
    # of ``ax``, so this line always sits above the CH4 line.
    line_var = ax_twin.plot(
        site_data['date'], site_data[var_name],
        color=var_color, linestyle='-', label=var_name, zorder=3
    )
    ax_twin.set_ylabel(f'{var_name} ({var_unit})', color=var_color, fontsize=FS_AXLABEL)
    ax_twin.tick_params(axis='y', labelcolor=var_color, labelsize=FS_TICKS)

    ax.set_title(f'CH4 vs. {var_name}', fontsize=FS_SUBTITLE, fontweight='bold')

    # Combined legend. Attach it to the twin axis: a legend owned by ``ax`` is
    # drawn underneath everything on ``ax_twin``, so the variable line would
    # cross over the legend box.
    handles = line_ch4 + line_var
    ax_twin.legend(handles, [h.get_label() for h in handles],
                   loc='upper left', fontsize=FS_LEGEND)
    return ax_twin

# --- variables to plot: column -> (axis-label unit, line colour) ---
# Order matters: panels are filled row by row in this order.
plots = [
    (column, unit, colour)
    for column, (unit, colour) in {
        'NDVI':        ('unitless',   'forestgreen'),
        'pr':          ('mm',         'dodgerblue'),
        'tmean_C':     ('°C',         'indianred'),
        'snow_depth':  ('meters',     'saddlebrown'),
        'sm_surface':  ('Volumetric', 'purple'),
        'sm_rootzone': ('Volumetric', 'darkorange'),
    }.items()
]

# --- plotting loop: one 3x2 comparison figure per site ---
for site in all_sites:
    site_df = (
        df_new[df_new['site_reference'] == site]
        .dropna(subset=['ch4_flux_total'])
        .sort_values('date')
    )

    if site_df.empty:
        print(f" -> Skipping site: {site} (No valid CH4 observations)")
        continue

    print(f" -> Processing site: {site}")

    # 3x2 figure
    fig, axes = plt.subplots(3, 2, figsize=(30, 20), sharex=True)
    fig.suptitle(f"Site: {site} - Variable Comparison with CH4",
                 fontsize=FS_TITLE, fontweight='bold')

    # Materialise the panels as a list. `axes.flat` is a one-shot iterator:
    # once the zip() below has consumed it, looping over it again yields
    # nothing and the x-axis formatting would silently be skipped.
    ax_list = list(axes.flat)
    for ax, (vname, vunit, vcolor) in zip(ax_list, plots):
        if vname not in site_df.columns:
            ax.text(0.5, 0.5, f"Missing: {vname}", ha='center', va='center', fontsize=FS_SUBTITLE)
            ax.set_axis_off()
            continue
        create_comparison_plot(ax, site_df, vname, vunit, vcolor)

    # Blank out any panels left over when there are fewer variables than axes.
    for ax in ax_list[len(plots):]:
        ax.set_axis_off()

    # Common x-axis formatting. With sharex=True only the bottom row shows
    # x tick labels, so the 'Date' label goes on that row only.
    for ax in ax_list:
        ax.tick_params(axis='x', labelsize=FS_TICKS)
    for ax in axes[-1]:
        ax.set_xlabel('Date', fontsize=FS_AXLABEL)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    # Path separators in a site name would otherwise point into a non-existent folder.
    fname = str(site).replace("/", "_").replace(os.sep, "_")
    fig.savefig(os.path.join(comparison_plot_path, f"{fname}.png"), dpi=150)
    plt.close(fig)

print("\n--- ✅ Finished plotting. All site comparison plots are saved. ---")


--- 📂 Loading dataset ---
Plots will be saved to: /explore/nobackup/people/spotter5/anna_v/v2/exploration/variable_comparisons_methane
Found 188 unique sites to process.
 -> Skipping site: Skyttorp 2_SE-Sk2_tower (No valid CH4 observations)
 -> Skipping site: Wolf_creek_forest_CA-WCF_tower (No valid CH4 observations)
 -> Skipping site: Alberta - Western Peatland - LaBiche River,Black Spruce,Larch Fen_CA-WP1_tower (No valid CH4 observations)
 -> Skipping site: Elgeeii forest station_RU-Ege_tower (No valid CH4 observations)
 -> Skipping site: Faejemyr_SE-Faj_tower (No valid CH4 observations)
 -> Processing site: Fyodorovskoye2_RU-Fy2_tower
 -> Skipping site: Fyodorovskoye_RU-Fyo_tower (No valid CH4 observations)
 -> Skipping site: Gunnarsholt_IS-Gun_tower (No valid CH4 observations)
 -> Skipping site: HJP02 Jack Pine_CA-HJP02_tower (No valid CH4 observations)
 -> Skipping site: HJP75 Jack Pine_CA-HJP75_tower (No valid CH4 observations)
 -> Skipping site: HJP94 Jack Pine_CA-HJP94_tower (N

Next, make correlation heatmaps of each predictor against CH4 flux.

In [6]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Correlation heatmaps: predictors vs ch4_flux_total

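- Monthly (per site): rows = months 1..12, cols = predictors
- Seasonal (per site): rows = Winter (DJF) / Spring (MAM) / Summer (JJA) / Autumn (SON), cols = predictors
- Annual (ALL SITES TOGETHER): rows = sites, cols = predictors; each r pools
  all monthly records of that site
- Cells with fewer than MIN_PAIRS complete (predictor, CH4) pairs are NaN.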

Outputs (under OUT_DIR):
  /monthly/<site>_corr_monthly.csv
  /monthly/<site>_corr_heatmap_monthly.png
  /seasonal/<site>_corr_seasonal.csv
  /seasonal/<site>_corr_heatmap_seasonal.png
  /annual/annual_corr_by_site.csv
  /annual/annual_corr_heatmap_by_site.png
"""

import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

# Silence FutureWarning-category warnings only; other warnings still surface.
warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------- Config -----------------
# Input table, and root folder for every CSV/PNG written below.
IN_CSV = "/explore/nobackup/people/spotter5/anna_v/v2/v2_model_training_final.csv"
OUT_DIR = "/explore/nobackup/people/spotter5/anna_v/v2/corr_heatmaps_ch4"

# Candidate predictor columns (order = heatmap column order). Any that are
# absent from the CSV are skipped at runtime.
PREDICTORS = """
    EVI NDVI sur_refl_b01 sur_refl_b02 sur_refl_b03
    sur_refl_b07 NDWI pdsi srad tmean_C vap vs
    co2_cont ALT
    lai fpar Percent_NonTree_Vegetation
    Percent_NonVegetated Percent_Tree_Cover
    sm_surface sm_rootzone snow_cover snow_depth
""".split()

TARGET = "ch4_flux_total"   # column every predictor is correlated against
MIN_PAIRS = 12              # fewest complete (x, y) pairs for which r is reported
ROT_X = 90                  # x tick-label rotation (degrees) for predictor names

# ----------------- Helpers -----------------
def corr_with_min_pairs(x: pd.Series, y: pd.Series, min_pairs=MIN_PAIRS) -> float:
    """Pearson r between ``x`` and ``y`` over rows where both values are present.

    The two series are aligned on their index before incomplete rows are
    dropped. Returns NaN when fewer than ``min_pairs`` complete pairs remain.
    """
    paired = pd.DataFrame({"x": x, "y": y}).dropna()
    return paired["x"].corr(paired["y"]) if len(paired) >= min_pairs else np.nan

def season_label(month: int) -> str:
    """Map a month number to its meteorological season name.

    DJF -> "Winter", MAM -> "Spring", JJA -> "Summer"; every other value
    (including out-of-range ones) falls through to "Autumn".
    """
    for members, name in (((12, 1, 2), "Winter"),
                          ((3, 4, 5), "Spring"),
                          ((6, 7, 8), "Summer")):
        if month in members:
            return name
    return "Autumn"

def ensure_numeric(df: pd.DataFrame, cols: list) -> pd.DataFrame:
    """Coerce the listed columns of ``df`` to numeric, modifying ``df`` in place.

    Names in ``cols`` that are not columns of ``df`` are ignored; values that
    cannot be parsed become NaN. The same (mutated) frame is returned.
    """
    present = [name for name in cols if name in df.columns]
    for name in present:
        df[name] = pd.to_numeric(df[name], errors="coerce")
    return df

def plot_heatmap_annotated(matrix_df: pd.DataFrame, title: str, out_file: Path, annotate=True,
                           max_annot_cells: int = 2200):
    """Save a blue→red correlation heatmap, optionally writing r in each cell.

    Parameters
    ----------
    matrix_df : pd.DataFrame
        Values drawn on a fixed [-1, 1] colour scale; rows are groups, columns
        are predictors. NaN cells (too few pairs) are drawn in light grey.
    title : str
        Figure title.
    out_file : Path
        Output PNG path; parent directories are created if needed.
    annotate : bool, default True
        Write each non-NaN value (2 decimals) in its cell.
    max_annot_cells : int, default 2200
        Annotation is turned off when rows * cols exceeds this, to keep large
        figures legible.
    """
    n_rows, n_cols = matrix_df.shape
    # Size the figure to the matrix; width is capped at 40 in.
    fig_w = max(10, min(3 + 0.5 * n_cols, 40))
    fig_h = max(6,  2 + 0.45 * n_rows)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))

    # In "bwr", r = 0 maps to white -- the same as an unpainted NaN cell. Give
    # NaN its own colour so "not computed" is not mistaken for "no correlation".
    cmap = plt.get_cmap("bwr").copy()
    cmap.set_bad("lightgray")
    values = np.ma.masked_invalid(matrix_df.to_numpy(dtype=float))
    im = ax.imshow(values, cmap=cmap, vmin=-1, vmax=1,
                   origin="upper", aspect="auto")

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Pearson r (predictor vs CH4)", fontsize=11)

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    # Vertical/horizontal labels are centred on their tick. Oblique labels are
    # anchored at their right end so each one points at its column.
    if ROT_X % 180 in (0, 90):
        tick_kw = {"ha": "center"}
    else:
        tick_kw = {"ha": "right", "rotation_mode": "anchor"}
    ax.set_xticklabels(matrix_df.columns, fontsize=11, rotation=ROT_X, **tick_kw)
    ax.set_yticklabels(matrix_df.index, fontsize=12)

    ax.set_xlabel("Predictor variable", fontsize=12)
    # Prefer the index name (e.g. "Month", "Season") when the caller set one.
    ax.set_ylabel(matrix_df.index.name or "Group", fontsize=12)
    ax.set_title(title, fontsize=14, pad=12)

    # Annotate only when requested and the matrix is small enough to stay readable.
    if annotate and n_rows * n_cols <= max_annot_cells:
        fz = 9 if n_rows <= 40 else 7   # smaller text for tall matrices
        present = ~np.ma.getmaskarray(values)
        for i, j in zip(*np.nonzero(present)):
            ax.text(j, i, f"{values[i, j]:.2f}", ha="center", va="center",
                    color="black", fontsize=fz)

    fig.tight_layout()
    out_file.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_file, dpi=200)
    plt.close(fig)

# ----------------- Main -----------------
def _derive_tmean_c(df: pd.DataFrame) -> pd.DataFrame:
    """Add ``tmean_C`` = row mean of ``tmmn`` and ``tmmx`` when it is not already a column.

    Uses the same ``df[['tmmn', 'tmmx']].mean(axis=1)`` derivation as the
    plotting cell. ``df`` is returned unchanged when ``tmean_C`` exists or
    either source column is absent.
    """
    if "tmean_C" not in df.columns and {"tmmn", "tmmx"}.issubset(df.columns):
        temps = df[["tmmn", "tmmx"]].apply(pd.to_numeric, errors="coerce")
        df["tmean_C"] = temps.mean(axis=1)
    return df


def _group_corr_matrix(sd: pd.DataFrame, group_col: str, groups: list, preds: list) -> pd.DataFrame:
    """Pearson r of each predictor vs TARGET within each level of ``group_col`` (rows = groups)."""
    mat = pd.DataFrame(index=groups, columns=preds, dtype=float)
    for g in groups:
        sub = sd[sd[group_col] == g]
        for p in preds:
            mat.loc[g, p] = corr_with_min_pairs(sub[p], sub[TARGET])
    return mat


def _safe_filename(name) -> str:
    """Return ``name`` as a single path component (path separators replaced by '_')."""
    return str(name).replace("/", "_").replace(os.sep, "_")


def main():
    """Write monthly, seasonal and all-months correlation tables/heatmaps of PREDICTORS vs TARGET.

    Reads IN_CSV, keeps rows with ``flux_method == "EC"`` when that column
    exists, and writes CSV + PNG outputs under OUT_DIR (monthly/, seasonal/,
    annual/). Sites with fewer than MIN_PAIRS non-null TARGET values are skipped.

    Raises
    ------
    ValueError
        If a required column is missing or none of PREDICTORS is present.
    """
    print(f"Reading: {IN_CSV}")
    df = pd.read_csv(IN_CSV)

    if "flux_method" in df.columns:
        df = df[df["flux_method"] == "EC"].copy()

    required = ["site_reference", "year", "month", TARGET]
    missing_req = [c for c in required if c not in df.columns]
    if missing_req:
        raise ValueError(f"Missing required columns: {missing_req}")

    df["year"]  = pd.to_numeric(df["year"], errors="coerce").astype("Int64")
    df["month"] = pd.to_numeric(df["month"], errors="coerce").astype("Int64")
    df = df.dropna(subset=["site_reference", "year", "month"])
    df = df[(df["month"] >= 1) & (df["month"] <= 12)].copy()

    # tmean_C is listed in PREDICTORS but may exist only as tmmn/tmmx in the
    # CSV; derive it so it is not silently dropped as a "missing predictor".
    df = _derive_tmean_c(df)

    # keep predictors that actually exist
    preds_present = [p for p in PREDICTORS if p in df.columns]
    if not preds_present:
        raise ValueError("None of the specified predictors are present in the input CSV.")
    if len(preds_present) < len(PREDICTORS):
        print("Warning: missing predictors (ignored):",
              sorted(set(PREDICTORS) - set(preds_present)))

    df = ensure_numeric(df, [TARGET] + preds_present)
    df["season"] = df["month"].apply(season_label)

    # sites, sorted case-insensitively
    sites = sorted(df["site_reference"].dropna().unique(), key=lambda x: str(x).lower())
    print(f"Found {len(sites)} sites.")

    # output dirs
    out_monthly  = Path(OUT_DIR) / "monthly"
    out_seasonal = Path(OUT_DIR) / "seasonal"
    out_annual   = Path(OUT_DIR) / "annual"
    for p in (out_monthly, out_seasonal, out_annual):
        p.mkdir(parents=True, exist_ok=True)

    months = list(range(1, 13))
    seasons = ["Winter", "Spring", "Summer", "Autumn"]
    ann_rows, valid_sites = [], []

    # ---- Per-site monthly & seasonal, collecting the annual rows on the way ----
    for site in sites:
        sd = df[df["site_reference"] == site]

        if sd[TARGET].dropna().shape[0] < MIN_PAIRS:
            print(f" -> Skipping {site}: insufficient {TARGET} data.")
            continue

        fname = _safe_filename(site)

        # Monthly (12 x predictors)
        mon_mat = _group_corr_matrix(sd, "month", months, preds_present)
        mon_mat.index.name = "Month"
        mon_mat.to_csv(out_monthly / f"{fname}_corr_monthly.csv", float_format="%.4f")
        plot_heatmap_annotated(mon_mat, f"{site} — Monthly correlation (Predictors vs CH4)",
                               out_monthly / f"{fname}_corr_heatmap_monthly.png")

        # Seasonal (4 x predictors)
        sea_mat = _group_corr_matrix(sd, "season", seasons, preds_present)
        sea_mat.index.name = "Season"
        sea_mat.to_csv(out_seasonal / f"{fname}_corr_seasonal.csv", float_format="%.4f")
        plot_heatmap_annotated(sea_mat, f"{site} — Seasonal correlation (Predictors vs CH4)",
                               out_seasonal / f"{fname}_corr_heatmap_seasonal.png")

        print(f" -> Saved monthly/seasonal heatmaps for {site}")

        # All-months correlation for this site -> one row of the annual matrix
        ann_rows.append({p: corr_with_min_pairs(sd[p], sd[TARGET]) for p in preds_present})
        valid_sites.append(site)

    # ---- Annual (ALL SITES TOGETHER): rows = sites, cols = predictors ----
    if ann_rows:
        ann_mat = pd.DataFrame(ann_rows, index=valid_sites, columns=preds_present)
        ann_mat.index.name = "site_reference"
        ann_csv = out_annual / "annual_corr_by_site.csv"
        ann_png = out_annual / "annual_corr_heatmap_by_site.png"
        ann_mat.to_csv(ann_csv, float_format="%.4f")

        # annotate only if not too huge
        annotate = (ann_mat.shape[0] * ann_mat.shape[1] <= 2200)
        plot_heatmap_annotated(
            ann_mat,
            "Annual correlation (Predictors vs CH4) — rows = sites",
            ann_png,
            annotate=annotate
        )
        print(f" -> Saved combined annual matrix: {ann_csv}")
    else:
        print("No sites had sufficient data for the annual matrix.")

    print(f"\nDone. Results under:\n  {out_monthly}\n  {out_seasonal}\n  {out_annual}")

# Entry point. Inside a Jupyter kernel __name__ is "__main__", so running
# this cell executes main() directly.
if __name__ == "__main__":
    main()


Reading: /explore/nobackup/people/spotter5/anna_v/v2/v2_model_training_final.csv
Found 188 sites.
 -> Skipping Abisko Stordalen birch forest_tower: insufficient ch4_flux_total data.
 -> Skipping Adventdalen_SJ-Adv_tower: insufficient ch4_flux_total data.
 -> Skipping Alberta - Western Peatland - LaBiche River,Black Spruce,Larch Fen_CA-WP1_tower: insufficient ch4_flux_total data.
 -> Skipping Alberta - Western Peatland - Poor Fen (Sphagnum moss)_CA-WP2_tower: insufficient ch4_flux_total data.
 -> Skipping Alberta - Western Peatland - Rich Fen  (Carex)_CA-WP3_tower: insufficient ch4_flux_total data.
 -> Skipping Anaktuvuk River Moderate Burn_US-An2_tower: insufficient ch4_flux_total data.
 -> Skipping Anaktuvuk River Severe Burn_US-An1_tower: insufficient ch4_flux_total data.
 -> Skipping Anaktuvuk River Unburned_US-An3_tower: insufficient ch4_flux_total data.
 -> Skipping Andoya_NO-And_tower: insufficient ch4_flux_total data.
 -> Saved monthly/seasonal heatmaps for ARM-NSA-Barrow_US-A10