# In [None]:
import os
import pandas as pd
import numpy as np
from SALib.analyze import morris as morris_analyze
import matplotlib.pyplot as plt

# --- CONFIG ---
# Root folder of the model run; every input and output path below is built from it.
BASE_DIR = "/scratch/hjh7hp/Watershed_22_2025_fall/Watershed22_with_new_summer/Sharadha_khola_watershed/1976_SA_1/salyan/model"

# Input tables.
PARAM_BOUNDS_CSV = os.path.join(BASE_DIR, "Parameters_range_values_Final_with_zone.csv")
INPUT_CSV = os.path.join(BASE_DIR, "All_Parameters_with_Objective_functions.csv")

# Objective-function columns, and the columns that are never treated as parameters.
METRICS = ['NSE', 'RMSE', 'R2', 'logNSE']
EXCLUDE_COLS = {'job', 'defs', *METRICS}

# Output folders: one for the full tables, one nested folder for the top-N subset.
TOP_N = 20
RESULTS_DIR = os.path.join(BASE_DIR, "Morris_results")
TOP_N_DIR = os.path.join(RESULTS_DIR, f'top_{TOP_N}')
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(TOP_N_DIR, exist_ok=True)

def wrap_label(label, width=12):
    """Break a label into short lines, e.g. for use as a plot tick label.

    The label is split at every underscore (the underscores themselves are
    dropped). Any piece longer than ``width`` characters is further cut into
    consecutive chunks of at most ``width`` characters. All resulting pieces
    are joined with newlines.

    Args:
        label: Text to wrap.
        width: Maximum number of characters per output line; must be >= 1.

    Returns:
        The wrapped label as one newline-separated string.

    Raises:
        ValueError: If ``width`` is smaller than 1, because chunking with a
            non-positive width can never consume the text.
    """
    if width < 1:
        raise ValueError(f"width must be at least 1, got {width!r}")
    lines = []
    for segment in label.split('_'):
        # max(..., 1) keeps an empty segment (from "__" or a leading/trailing
        # underscore) as a single empty line.
        lines.extend(
            segment[start:start + width]
            for start in range(0, max(len(segment), 1), width)
        )
    return '\n'.join(lines)

print("Loading data and filtering parameter columns...")
# Run table (read from INPUT_CSV) and parameter-range table (read from PARAM_BOUNDS_CSV).
metrics_df = pd.read_csv(INPUT_CSV)
param_bounds_df = pd.read_csv(PARAM_BOUNDS_CSV)

# Keep every name listed in the bounds table that is also a column of the run
# table, is not in EXCLUDE_COLS, and whose values all convert to float.
all_potential_param_cols = []
for col in param_bounds_df['Parameter name']:
    if col in EXCLUDE_COLS or col not in metrics_df.columns:
        continue
    try:
        # The conversion is only a check; its result is intentionally discarded.
        metrics_df[col].astype(float)
    except (ValueError, TypeError):
        # Only conversion failures mean "non-numeric". Any other error is a
        # real problem and should surface instead of silently dropping a column.
        print(f"Excluding parameter '{col}' (non-numeric values detected).")
    else:
        all_potential_param_cols.append(col)
numeric_param_cols = all_potential_param_cols

print(f"Number of numeric parameter columns to use: {len(numeric_param_cols)}")

# Map each selected parameter to its (lower limit, upper limit) pair, taken
# from the first row of the bounds table that carries the parameter's name.
used_param_bounds = {}
for p in numeric_param_cols:
    rows_for_param = param_bounds_df.loc[
        param_bounds_df['Parameter name'] == p,
        ['lower limit', 'upper limit'],
    ]
    used_param_bounds[p] = tuple(rows_for_param.values[0])

# Work on a copy so the raw table stays untouched; coerce parameter columns to
# numeric (anything unparsable becomes NaN).
metrics_df_numeric = metrics_df.copy()
for col in numeric_param_cols:
    metrics_df_numeric[col] = pd.to_numeric(metrics_df_numeric[col], errors='coerce')

# Only index with metric columns that exist: selecting a missing column raises
# KeyError, which would abort the script before the per-metric
# "not found. Skipping." handling further down could ever run.
available_metrics = [m for m in METRICS if m in metrics_df_numeric.columns]

# A row is usable when at least one metric is present and no parameter is missing.
valid_mask = metrics_df_numeric[available_metrics].notna().any(axis=1)
param_mask = ~metrics_df_numeric[numeric_param_cols].isna().any(axis=1)
metrics_df_cleaned = metrics_df_numeric.loc[valid_mask & param_mask, :]

n_cleaned = int(metrics_df_cleaned.shape[0])
if n_cleaned == 0:
    print("No usable simulations remain after filtering for numeric parameter columns.")
    # `exit()` is an interactive-shell helper injected by `site`; raise the
    # exception it stands for so this also works where that helper is absent.
    raise SystemExit
print(f"Rows usable for Morris analysis: {n_cleaned}")

# A Morris trajectory is a block of (number of parameters + 1) consecutive
# rows, and the analysis compares neighbouring rows inside each block.
# Dropping single rows and re-chunking the survivors would shift every later
# trajectory, so each trajectory that contains an unusable row is discarded
# as a whole instead.
# NOTE(review): assumes the CSV holds one row per sampled point in the
# sampler's original order (failed runs present as NaN rows) — confirm.
traj_len = int(len(numeric_param_cols)) + 1
row_ok = (valid_mask & param_mask).to_numpy()
n_traj_total = row_ok.size // traj_len
n_rows_aligned = n_traj_total * traj_len  # a trailing partial trajectory is ignored
traj_ok = row_ok[:n_rows_aligned].reshape(n_traj_total, traj_len).all(axis=1)
n_traj = int(traj_ok.sum())
n_expected = n_traj * traj_len
n_traj_dropped = n_traj_total - n_traj
if n_traj_dropped:
    print(f"Discarded {n_traj_dropped} trajectories containing unusable rows.")
df_morris = metrics_df_numeric.iloc[np.flatnonzero(np.repeat(traj_ok, traj_len)), :]
print(f"Analysis uses {df_morris.shape[0]} rows, {n_traj} full trajectories.")

# Problem definition handed to the Morris analysis: parameter count, names,
# and the bounds listed in the same order as the names.
problem = dict(
    num_vars=len(numeric_param_cols),
    names=numeric_param_cols,
    bounds=[used_param_bounds[name] for name in numeric_param_cols],
)
# Per-metric result tables (all parameters / top-N only) are collected here.
results_tables, top_results_tables = [], []
# Bar colours, indexed by the metric's position in METRICS.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

# For every metric: run the Morris analysis, write the full and top-N ranking
# tables, and save a bar plot of the top-N mu* values.
for idx, metric_col in enumerate(METRICS):
    if metric_col not in df_morris.columns:
        print(f"Metric '{metric_col}' not found. Skipping.")
        continue

    # Rows come in blocks of traj_len (one Morris trajectory each). Removing
    # only the rows whose metric is NaN and re-chunking the rest would shift
    # every later trajectory, so a trajectory with any missing metric value is
    # dropped as a whole.
    metric_ok = df_morris[metric_col].notna().to_numpy()
    n_metric = int(metric_ok.sum())
    n_blocks = metric_ok.size // traj_len
    block_ok = metric_ok[:n_blocks * traj_len].reshape(n_blocks, traj_len).all(axis=1)
    traj_metric = int(block_ok.sum())
    n_metric_expected = traj_metric * traj_len
    if traj_metric == 0:
        print(f"Not enough complete trajectories for {metric_col}. Skipping.")
        continue
    keep_rows = np.flatnonzero(np.repeat(block_ok, traj_len))
    kept = df_morris.iloc[keep_rows]
    X = kept[numeric_param_cols].values.astype(float)
    Y = kept[metric_col].values.astype(float)

    # NOTE(review): num_levels should equal the value used when the Morris
    # sample was generated — confirm against the sampling script.
    num_levels = 4
    results = morris_analyze.analyze(
        problem,
        X,
        Y,
        num_levels=num_levels,
        print_to_console=False
    )

    # One row per parameter, ranked by mu_star (largest first).
    results_df = pd.DataFrame({
        'Parameter': problem['names'],
        'mu_star': results['mu_star'],
        'sigma': results['sigma'],
        'mu': results['mu'],
        'mu_star_conf': results['mu_star_conf'],
        'metric': metric_col,
        'lower_limit': [bounds[0] for bounds in problem['bounds']],
        'upper_limit': [bounds[1] for bounds in problem['bounds']]
    }).sort_values('mu_star', ascending=False).reset_index(drop=True)
    results_tables.append(results_df)
    metric_csv_all = os.path.join(RESULTS_DIR, f'morris_sensitivity_{metric_col}_all.csv')
    results_df.to_csv(metric_csv_all, index=False)

    # Top-N subset of the ranking, saved separately.
    top_df = results_df.iloc[:TOP_N, :].copy()
    top_results_tables.append(top_df)
    metric_csv_top = os.path.join(TOP_N_DIR, f'top_{TOP_N}_morris_sensitivity_{metric_col}.csv')
    top_df.to_csv(metric_csv_top, index=False)

    # Bar plot of the top-N mu* values, each bar annotated with its height.
    plt.figure(figsize=(12, 5))
    params_sorted = [wrap_label(p, 12) for p in top_df['Parameter']]
    mu_star_sorted = top_df['mu_star']
    bars = plt.bar(range(len(params_sorted)), mu_star_sorted, color=colors[idx % len(colors)], edgecolor='black', alpha=0.85)
    plt.ylabel('Morris μ*', fontsize=14, fontweight='bold')
    plt.title(f'Top {TOP_N} Morris Sensitivity ({metric_col})', fontsize=15, fontweight='bold', pad=12)
    plt.xticks(range(len(params_sorted)), params_sorted, rotation=0, ha='center', fontsize=8, color='black')
    for bar in bars:
        height = bar.get_height()
        plt.annotate(f"{height:.2f}", xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 4), textcoords="offset points", ha='center', va='bottom', fontsize=11)
    plt.xlabel('Parameter (ranked)', fontsize=15, fontweight='bold')
    plt.tight_layout(rect=[0, 0.03, 1, 1])
    plot_path = os.path.join(TOP_N_DIR, f'top_{TOP_N}_morris_{metric_col}.png')
    plt.savefig(plot_path, dpi=300)
    plt.close()
# Stack the per-metric tables into one combined CSV each (all parameters, and
# top-N only). Nothing is written when no metric produced a table.
if len(results_tables) > 0:
    results_all = pd.concat(results_tables, ignore_index=True)
    results_all.to_csv(
        os.path.join(RESULTS_DIR, 'morris_sensitivity_all_metrics_ranked.csv'),
        index=False,
    )
if len(top_results_tables) > 0:
    top_results_all = pd.concat(top_results_tables, ignore_index=True)
    top_results_all.to_csv(
        os.path.join(TOP_N_DIR, f'top_{TOP_N}_morris_sensitivity_all_metrics_ranked.csv'),
        index=False,
    )
print(f"\nTop {TOP_N} Morris sensitivity results and bar plots saved in:\n {TOP_N_DIR}")
