# In [None]:
import os, shutil, numpy as np, re, pandas as pd
from SALib.sample import morris

# ==============================
# CONFIG
# ==============================
# Root of the model setup: def files are read from BASE_DIR/defs and the
# sensitivity-analysis folders / CSVs are written under BASE_DIR.
BASE_DIR = "/scratch/hjh7hp/Watershed_22_2025_fall/Watershed22_with_new_summer/Sharadha_khola_watershed/1976_SA_1/salyan/model"
SOURCE_DEFS_DIR = os.path.join(BASE_DIR, "defs")
SA_DEFS_PARENT = os.path.join(BASE_DIR, "SA_defs_morris")
os.makedirs(SA_DEFS_PARENT, exist_ok=True)

# Morris design settings (used by the morris.sample call further down).
RANDOM_SEED = 2025
NUM_TRAJECTORIES = 30
NUM_LEVELS = 10

# Def files copied verbatim into every SA folder (never perturbed).
unchanged_files = {
    "basin_basin.def", "hill_setting1.def", "landuse_setting5.def",
    "stratum_crop.def", "stratum_grass.def", "stratum_shrub.def",
}

# Every other *.def file in SOURCE_DEFS_DIR is perturbed. zone_setting2010.def
# is not in unchanged_files, so it is picked up automatically when present.
# Sorted so the processing order does not depend on the directory listing.
files_to_change = sorted(
    f for f in os.listdir(SOURCE_DEFS_DIR)
    if f.endswith(".def") and f not in unchanged_files
)

# Per def file: tuples of parameter names whose sampled values are
# renormalised to sum to 1 (see enforce_sum_to_one).
_SOIL_TEXTURE_GROUPS = (('clay', 'sand', 'silt'),)
_LITTER_FRACTION_GROUPS = (
    ('epc.leaflitr_fcel', 'epc.leaflitr_flab', 'epc.leaflitr_flig'),
    ('epc.frootlitr_fcel', 'epc.frootlitr_flab', 'epc.frootlitr_flig'),
    ('epc.deadwood_fcel', 'epc.deadwood_flig'),
)
_SOIL_DEF_FILES = (
    "soil_loam_9.def",
    "soil_sand_10.def",
    "soil_sandy_loam_12.def",
    "soil_silt_loam_8.def",
    "soil_silty_clay_loam_3.def",
)
_STRATUM_DEF_FILES = (
    "stratum_cwt_rhododendron_bgc.def",
    "stratum_deciduous.def",
    "stratum_eastern_white_pine.def",
    "stratum_evergreen.def",
    "stratum_localdeciduous.def",
)
SUM_ONE_GROUPS = {}
for _def_file in _SOIL_DEF_FILES:
    SUM_ONE_GROUPS[_def_file] = list(_SOIL_TEXTURE_GROUPS)
for _def_file in _STRATUM_DEF_FILES:
    SUM_ONE_GROUPS[_def_file] = list(_LITTER_FRACTION_GROUPS)

# Day-of-year style parameters: clamped to [0, 365] and written as ints.
DAY_INT_PARAMS = {
    "epc.day_leafoff",
    "epc.day_leafon",
    "epc.ndays_expand",
    "epc.ndays_litfall",
}

# Parameters floored to a multiple of 1000 and written as ints
# (see enforce_multiple_1000 / is_integer_param).
MULT1000_PARAMS = {
    "epc.vpd_close (x1000)",
    "epc.vpd_open (x1000)",
    "gsurf_intercept",
    "snow_light_ext_coef",
}

# Vegetation class -> (leaf_min, leaf_max, litter_min, litter_max) C:N bounds
# used by enforce_cn. All classes currently share the same bounds.
CN_RANGES = {
    veg_class: (20, 50, 45, 98)
    for veg_class in ("broadleaf", "pine", "alder", "grass")
}


def detect_veg_type(fn):
    """Classify a def-file name into one of the CN_RANGES vegetation classes.

    Matching is a case-insensitive substring test, checked in priority order:
    "pine" or "evergreen" -> "pine", "alder" -> "alder", "grass" -> "grass";
    any other name falls back to "broadleaf".
    """
    lowered = fn.lower()
    priority_rules = (
        (("pine", "evergreen"), "pine"),
        (("alder",), "alder"),
        (("grass",), "grass"),
    )
    for needles, veg_class in priority_rules:
        if any(needle in lowered for needle in needles):
            return veg_class
    return "broadleaf"

def enforce_sum_to_one(params, file_prefix, groups=None):
    """Renormalise each sum-to-one parameter group of one def file, in place.

    For every group of names, the values stored under
    ``f"{file_prefix}_{name}"`` are clipped at 0 and rescaled so the group sums
    to 1. Missing keys count as 0 and are added to ``params``; an all-zero group
    gets equal shares. Values are rounded to 8 decimals and the last member
    absorbs the rounding residue, so the rounded group sums to 1.

    Args:
        params: mapping of prefixed parameter key -> sampled value (mutated).
        file_prefix: def file name without the ".def" suffix.
        groups: optional iterable of name tuples to enforce; defaults to
            ``SUM_ONE_GROUPS.get(file_prefix + ".def", [])``.

    Returns:
        The same ``params`` mapping.
    """
    decimals = 8
    if groups is None:
        groups = SUM_ONE_GROUPS.get(file_prefix + ".def", [])
    for grp in groups:
        if not grp:
            # Nothing to normalise (and 1/len(grp) would divide by zero).
            continue
        keys = [f"{file_prefix}_{p}" for p in grp]
        vals = [max(params.get(k, 0), 0) for k in keys]
        total = sum(vals)
        if total == 0:
            vals = [1.0 / len(grp)] * len(grp)
        else:
            vals = [v / total for v in vals]
        head = [round(v, decimals) for v in vals[:-1]]
        last = round(1.0 - sum(head), decimals)
        if last < 0:
            # When the last share is ~0, rounding the leading members up can
            # overshoot 1 by a few 1e-9. Take the excess from the largest leading
            # member instead of emitting a negative fraction.
            i_max = max(range(len(head)), key=head.__getitem__)
            head[i_max] = round(head[i_max] + last, decimals)
            last = 0.0
        for k, v in zip(keys, head + [last]):
            params[k] = v
    return params

def enforce_cn(params, file_prefix):
    """Clip one def file's leaf and leaf-litter C:N ratios into CN_RANGES.

    Leaf C:N is clipped to [leaf_min, leaf_max] for the file's vegetation
    class (from detect_veg_type). Leaf-litter C:N is then clipped to
    [max(litter_min, leaf C:N + 5), litter_max]. When leaf C:N is not in
    ``params``, leaf_min stands in for it. Mutates and returns ``params``.
    """
    leaf_lo, leaf_hi, litter_lo, litter_hi = CN_RANGES[detect_veg_type(file_prefix)]
    leaf_key = file_prefix + "_epc.leaf_cn"
    litter_key = file_prefix + "_epc.leaflitr_cn"

    if leaf_key in params:
        params[leaf_key] = np.clip(params[leaf_key], leaf_lo, leaf_hi)

    if litter_key not in params:
        return params

    # Litter C:N must sit at least 5 units above the (already clipped) leaf C:N.
    litter_floor = max(litter_lo, params.get(leaf_key, leaf_lo) + 5)
    params[litter_key] = np.clip(params[litter_key], litter_floor, litter_hi)
    return params

def enforce_days(params, file_prefix):
    """Integerise one def file's day parameters and keep leaf-on before leaf-off.

    Every DAY_INT_PARAMS entry present for ``file_prefix`` is clamped to
    [0, 365] and rounded to an int. If leaf-on is not strictly before
    leaf-off, leaf-on is capped at day 180 and leaf-off is raised to at least
    leaf-on + 30, then capped at 365. Mutates and returns ``params``.
    """
    for day_name in DAY_INT_PARAMS:
        day_key = f"{file_prefix}_{day_name}"
        if day_key not in params:
            continue
        bounded = min(365, max(0, params[day_key]))
        params[day_key] = int(round(bounded))

    on_key = f"{file_prefix}_epc.day_leafon"
    off_key = f"{file_prefix}_epc.day_leafoff"
    if on_key in params and off_key in params and params[on_key] >= params[off_key]:
        leaf_on = min(params[on_key], 180)
        params[on_key] = leaf_on
        params[off_key] = min(365, max(leaf_on + 30, params[off_key]))
    return params

def enforce_multiple_1000(params):
    """Floor every MULT1000_PARAMS value down to a multiple of 1000, as an int.

    A key matches when it equals a listed name or ends with "_" + that name
    (a file-prefixed key). Mutates and returns ``params``.

    NOTE(review): any matching value below 1000 becomes 0 here. If
    gsurf_intercept / snow_light_ext_coef are sampled on a small natural
    scale, that collapses them; confirm the sampled scale in the bounds CSV.
    """
    def _is_mult1000_key(key):
        return any(key == name or key.endswith("_" + name) for name in MULT1000_PARAMS)

    for key in [k for k in params if _is_mult1000_key(k)]:
        params[key] = int(np.floor(params[key] / 1000) * 1000)
    return params

def enforce_all_constraints(params):
    """Apply all per-file constraints, then the global multiple-of-1000 rule.

    For each def file in files_to_change, applies enforce_sum_to_one,
    enforce_cn and enforce_days (in that order) to that file's prefixed keys.
    Finally applies enforce_multiple_1000. Returns the resulting mapping.
    """
    per_file_rules = (enforce_sum_to_one, enforce_cn, enforce_days)
    for def_name in files_to_change:
        prefix = def_name.replace('.def', '')
        for rule in per_file_rules:
            params = rule(params, prefix)
    return enforce_multiple_1000(params)


# "<number> <name>" at the start of a def-file line. The number may be "1",
# "1.5", ".5", "1." and may carry an exponent. The name may carry the
# " (= lapse rate)" suffix used by zone parameters. Compiled once, not per line.
_PARAM_LINE_RE = re.compile(
    r'^\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)[ \t]+'
    r'([a-zA-Z0-9_\.]+(?: \(= lapse rate\))?)'
)


def parse_param_line(line):
    """Split a def-file line into ``(value, name, rest)``.

    Returns:
        ``(float(value), name, rest)``, where ``rest`` is everything after the
        matched name (trailing text and the newline, if any), or
        ``(None, None, None)`` when the line does not start with a number
        followed by whitespace and a parameter name.
    """
    m = _PARAM_LINE_RE.match(line)
    if m is None:
        return None, None, None
    return float(m.group(1)), m.group(2), line[m.end():]

def is_integer_param(file_prefix, param_name):
    """Return True when a parameter must be written as an integer.

    Integer-typed parameters are those whose bare name is in DAY_INT_PARAMS or
    MULT1000_PARAMS, or whose prefixed key ``f"{file_prefix}_{param_name}"``
    ends with "_" + one of those names. Everything else is a float; there is
    no extra integer handling for zone parameters.
    """
    integer_names = DAY_INT_PARAMS | MULT1000_PARAMS
    if param_name in integer_names:
        return True
    full_key = f"{file_prefix}_{param_name}"
    return any(full_key.endswith(f"_{name}") for name in integer_names)

# Load per-parameter ranges and build the SALib Morris problem from them.
bounds = pd.read_csv(os.path.join(BASE_DIR, "Parameters_range_values_Final_with_zone.csv"))
param_names, bnds = [], []
for _, r in bounds.iterrows():
    # Skip rows whose limits are not numeric. Only conversion errors are
    # caught: a missing/misnamed column (KeyError) now fails loudly instead of
    # silently producing an empty parameter list.
    try:
        lo, hi = float(r['lower limit']), float(r['upper limit'])
    except (TypeError, ValueError):
        continue
    # Only non-degenerate ranges are sampled; fixed (lo == hi), inverted or
    # NaN ranges are left out of the design.
    if lo < hi:
        param_names.append(r['Parameter name'])
        bnds.append([lo, hi])

problem = {'num_vars': len(param_names), 'names': param_names, 'bounds': bnds}
X = morris.sample(problem, N=NUM_TRAJECTORIES, num_levels=NUM_LEVELS, seed=RANDOM_SEED)
raw_df = pd.DataFrame(X, columns=param_names)

def _finalize_sample_value(key, value):
    """Cast an integer-typed sample value to int; round floats to 8 decimals.

    The key is split at its last underscore into (prefix, tail) for
    is_integer_param. MULT1000 tails are floored to a multiple of 1000, other
    integer parameters are rounded. Non-integer floats are rounded to 8
    decimals; any other value is returned unchanged.
    """
    if "_" in key:
        prefix, tail = key.rsplit("_", 1)
        if is_integer_param(prefix, tail):
            if tail in MULT1000_PARAMS:
                return int(np.floor(value / 1000) * 1000)
            return int(round(value))
    return round(value, 8) if isinstance(value, float) else value


# Apply all constraints to each Morris row, then normalise value types.
adjusted_samples = []
for _, row in raw_df.iterrows():
    params = enforce_all_constraints(row.to_dict())
    params = {key: _finalize_sample_value(key, val) for key, val in params.items()}
    adjusted_samples.append(params)

# Keep only the sampled columns (constraint helpers may add extra keys) and
# label each row with the defs folder it will be written to.
adjusted_df = pd.DataFrame(adjusted_samples, columns=param_names)
adjusted_df.insert(0, "defs_set", [f"defs{n}" for n in range(1, len(adjusted_df) + 1)])
adjusted_df.to_csv(os.path.join(BASE_DIR, "defs_parameter_mapping.csv"), index=False)
adjusted_df.drop(columns="defs_set").to_csv(os.path.join(BASE_DIR, "morris_parameter_set_full.csv"), index=False)

def write_defs(sample_params, idx):
    """Write one perturbed copy of the def files to SA_DEFS_PARENT/defs{idx}.

    Files in unchanged_files are copied verbatim via shutil.copy2. For each
    file in files_to_change, a line is rewritten when all of the following hold:
    parse_param_line recognises it as "<value> <name>...", the name does not
    end with "default_ID", and the key f"{prefix}_{name}" is in
    ``sample_params``. All other lines are copied unchanged.

    NOTE(review): parse_param_line never captures a " (x1000)" suffix, so
    sample keys ending in " (x1000)" cannot match here. Confirm how
    epc.vpd_open / epc.vpd_close are named in the def files and the CSV.

    Args:
        sample_params: mapping of prefixed parameter key -> sampled value.
        idx: index used in the output folder name ("defs{idx}").

    Returns:
        Path of the defs folder that was written.
    """
    defsdir = os.path.join(SA_DEFS_PARENT, f"defs{idx}")
    os.makedirs(defsdir, exist_ok=True)
    for f in unchanged_files:
        shutil.copy2(os.path.join(SOURCE_DEFS_DIR, f), os.path.join(defsdir, f))
    for f in files_to_change:
        file_prefix = f.replace('.def', '')
        # Context managers close the handles (and flush the output) even if
        # formatting a value raises part-way through the file.
        with open(os.path.join(SOURCE_DEFS_DIR, f)) as src:
            lines = src.read().splitlines(True)
        out = []
        for ln in lines:
            _, p, rest = parse_param_line(ln)
            if not p or p.endswith("default_ID"):
                out.append(ln)
                continue
            k = f"{file_prefix}_{p}"
            if k not in sample_params:
                out.append(ln)
                continue
            nv = sample_params[k]
            if is_integer_param(file_prefix, p):
                if p in MULT1000_PARAMS:
                    nv = int(np.floor(nv / 1000) * 1000)
                out.append(f"{int(nv)} {p}{rest}")
            elif isinstance(nv, float) and float(nv).is_integer():
                # Whole-number floats are written without a decimal part.
                out.append(f"{int(nv)} {p}{rest}")
            else:
                out.append(f"{nv:.8f} {p}{rest}")
        with open(os.path.join(defsdir, f), 'w') as dst:
            dst.writelines(out)
    return defsdir

# Write one defs folder per adjusted sample and record the values used in each.
maprec = []
for i, row in adjusted_df.iterrows():
    folder_label = f"defs{i+1}"
    defsdir = write_defs(row.to_dict(), i + 1)
    record = row.to_dict()
    record['defs_set'] = folder_label
    maprec.append(record)
    print("Created:", defsdir)

NUM_DEFS_FOLDERS = len(adjusted_df)

mapping_path = os.path.join(BASE_DIR, "defs_parameter_mapping.csv")
pd.DataFrame(maprec).to_csv(mapping_path, index=False)
print("Mapping file written.")
