import argparse
import json
import math
from pathlib import Path

import ansys.fluent.core as pyfluent

# JSON configuration files; run_case() resolves both paths relative to its src_dir argument.
DEFAULT_NAMED_EXPRESSION_PATH = Path("config/default_named_expressions.json")
REPORT_DEFINITIONS_PATH = Path("config/report_definitions.json")

# Case description file; run_case() reads it from the current working directory.
CASE_JSON_FILENAME = "case.json"

# Keys that prepare_named_expressions() turns into named expressions;
# any other key of the input dict is ignored. The suffix of each name selects
# its unit (see unit_for_named_expr).
NAMED_EXPRESSIONS = ["p_atm_pa", "temp_in_degc", "temp_base_degc", "phi_in_1", "k_solid_w_mk", "eps_base_1",
                     "u_in_m_s", "t_layer_m", "f_l_1", "f_w_1", "q_base_w_m2", "q_fan_w", "h_m", "w_m",
                     "k_singular_1", "eta_fan_1", "k_layer_w_mk", "t_hydrogel_m"]

# NOTE(review): neither constant is referenced in the visible part of this file;
# presumably iteration-count tuning values - confirm they are still needed.
ITER_MARGIN = 3
MIN_ITER_PI = 5


def unit_for_named_expr(name: str) -> str:
    """Return the bracketed unit string implied by the suffix of *name*.

    The suffixes are tested in a fixed order and the first match wins.
    A name with no recognised suffix is treated as dimensionless and
    yields an empty string.
    """
    suffix_to_unit = (
        ("_degc", "[C]"),
        ("_m_s", "[m/s]"),
        ("_w_mk", "[W/(m*K)]"),
        ("_m3_s", "[m^3/s]"),
        ("_pa", "[Pa]"),
        ("_m", "[m]"),
        ("_w_m2", "[W/(m^2)]"),
        ("_w", "[W]"),
    )
    for suffix, unit in suffix_to_unit:
        if name.endswith(suffix):
            return unit
    # Dimensionless by default (phi, eps, etc.)
    return ""


def prepare_named_expressions(named_expressions: dict):
    """Convert the recognised entries of a dict into named-expression records.

    Only keys listed in ``NAMED_EXPRESSIONS`` are kept. Numeric values are
    rendered with the ``g`` format followed by a space and the unit derived
    from the key's suffix; any other value is passed through ``str()``.

    Returns:
        A list of ``{"name": ..., "expression": ...}`` dicts, in the
        iteration order of the input dict.
    """
    prepared = []
    for name, raw_value in named_expressions.items():
        if name not in NAMED_EXPRESSIONS:
            continue
        if isinstance(raw_value, (int, float)):
            expression = f"{raw_value:g} {unit_for_named_expr(name)}"
        else:
            expression = str(raw_value)
        prepared.append({"name": name, "expression": expression})
    return prepared


def set_named_expressions(solver, evaporation, named_expressions: list):
    """Create each named expression in the solver and assign its definition.

    Args:
        solver: Session object exposing ``settings.setup.named_expressions``.
        evaporation: When falsy, entries carrying a truthy ``"evaporation"``
            key are skipped.
        named_expressions: Dicts with ``"name"`` and ``"expression"`` keys and
            an optional ``"evaporation"`` flag.
    """
    registry = solver.settings.setup.named_expressions
    for entry in named_expressions:
        evaporation_only = entry.get("evaporation", None)
        if evaporation_only and not evaporation:
            continue
        expr_name = entry["name"]
        registry.create(expr_name)
        registry[expr_name].definition = entry["expression"]


def set_air_properties(air, constant: bool):
    """Configure specific heat, thermal conductivity and viscosity of air.

    Args:
        air: Material object exposing ``specific_heat``,
            ``thermal_conductivity`` and ``viscosity`` settings.
        constant: When true, each property is set to a fixed value;
            otherwise each property is set to a polynomial in temperature.
    """
    if constant:
        # Fixed air properties at T=25degC
        fixed_values = (
            ("specific_heat", 1006.32),
            ("thermal_conductivity", 0.0262469),
            ("viscosity", 1.84481e-05),
        )
        for prop_name, fixed in fixed_values:
            prop = getattr(air, prop_name)
            prop.option = "constant"
            prop.value = fixed
        return

    # Temperature-dependent polynomial coefficients, one list per property.
    polynomial_coefficients = (
        ("specific_heat",
         [1047.63657, -0.372589265, 9.45304214E-4, -6.02409443E-7, 1.2858961E-10]),
        ("thermal_conductivity",
         [-0.00227583562, 1.15480022E-4, -7.90252856E-8, 4.11702505E-11, -7.43864331E-15]),
        ("viscosity",
         [-8.38278E-7, 8.35717342E-8, -7.69429583E-11, 4.6437266E-14, -1.06585607E-17]),
    )
    for prop_name, coefficients in polynomial_coefficients:
        prop = getattr(air, prop_name)
        prop.option = "polynomial"
        prop.polynomial = {
            "function_of": "temperature",
            "coefficients": coefficients,
        }


def set_vapor_properties(vapor, constant: bool):
    """Configure specific heat, thermal conductivity and viscosity of vapor.

    Args:
        vapor: Material object exposing ``specific_heat``,
            ``thermal_conductivity`` and ``viscosity`` settings.
        constant: When true, each property is set to a fixed value;
            otherwise each property is set to a polynomial in temperature.
    """
    if constant:
        # Fixed vapor properties at T=25degC
        fixed_values = (
            ("specific_heat", 1911.82),
            ("thermal_conductivity", 0.0184333),
            ("viscosity", 9.70092e-6),
        )
        for prop_name, fixed in fixed_values:
            prop = getattr(vapor, prop_name)
            prop.option = "constant"
            prop.value = fixed
        return

    # Temperature-dependent polynomial coefficients, one list per property.
    polynomial_coefficients = (
        ("specific_heat",
         [4653.75324, -31.4458767, 0.141886474, -3.30846891E-4, 4.2664014E-7, -2.87854447E-10,
          7.93865279E-14]),
        ("thermal_conductivity",
         [0.0018875329, 3.57073326E-5, 6.42783151E-8, 5.35642899E-12, -9.97388219E-15]),
        ("viscosity",
         [5.30681701E-6, -1.07574032E-8, 1.15361668E-10, -1.1119871E-13, 3.91677879E-17]),
    )
    for prop_name, coefficients in polynomial_coefficients:
        prop = getattr(vapor, prop_name)
        prop.option = "polynomial"
        prop.polynomial = {
            "function_of": "temperature",
            "coefficients": coefficients,
        }


def set_report_definitions(rd, evaporation: bool, report_definitions: dict):
    """Create the configured report definitions on *rd*.

    Args:
        rd: Report-definitions container; the attribute named by each
            entry's ``"type"`` must support ``create(name)`` and indexing
            by name.
        evaporation: When false, the ``q_evap_w`` and ``phi_out_1`` reports
            are not created.
        report_definitions: Iterated as a sequence of dicts, each with at
            least ``"name"`` and ``"type"`` keys (despite the ``dict``
            annotation). Every key except ``"type"`` is set as an attribute
            on the created report definition.
    """
    evaporation_only_reports = ["q_evap_w", "phi_out_1"]
    for spec in report_definitions:
        report_name = spec["name"]
        if report_name in evaporation_only_reports and not evaporation:
            continue

        container = getattr(rd, spec["type"])
        container.create(report_name)
        for field, setting in spec.items():
            if field == "type":
                continue
            setattr(container[report_name], field, setting)


def _clamp(x, lo, hi):
    return lo if x < lo else hi if x > hi else x


def compute_report_def(rd, report_definition: str):
    """Compute one report definition and return its first value.

    Calls ``rd.compute`` for the single named report, takes the first
    element of the result, looks up the entry keyed by the report name and
    returns that entry's first item.
    """
    computed = rd.compute(report_defs=[report_definition])
    first_result = computed[0]
    report_entry = first_result[report_definition]
    return report_entry[0]


def run_case(src_dir: Path):
    """Set up a Fluent case from ``case.json`` and run a parameter sweep.

    The case description is read from ``case.json`` in the current working
    directory; the default named expressions and the report definitions are
    read from the ``config`` folder under *src_dir*. A Fluent session is
    launched, the mesh ``<mesh name>.msh.h5`` is loaded, and models,
    materials, boundary conditions, report definitions and convergence
    criteria are configured. The base case is written to
    ``<case name>_base.cas.h5`` and the solution is initialized. Then, for
    every entry of ``case["var_values"]``, the named expression
    ``case["variable_name"]`` is updated, the solver is iterated, and the
    output parameters and data are written to disk.

    Args:
        src_dir: Directory containing the ``config`` folder with the JSON
            configuration files.

    Raises:
        ValueError: If ``case["heating_type"]`` is neither ``"T_base"`` nor
            ``"q_base"``, or if the residual monitor lists an equation for
            which no convergence criterion is handled.
        RuntimeError: If a sweep step uses up all ``n_iter`` iterations.
    """
    print("=== Load JSON and Mesh ===")
    # JSON to dict
    case_json_path = Path.cwd() / CASE_JSON_FILENAME
    case = json.loads(case_json_path.read_text())

    # Load configurations
    default_named_expressions = json.loads((src_dir / DEFAULT_NAMED_EXPRESSION_PATH).read_text())
    report_definitions = json.loads((src_dir / REPORT_DEFINITIONS_PATH).read_text())

    # Aliases
    geom = case["mesh"]["geom"]
    solver_params = case["solver_params"]

    case_name = case["name"]
    mesh_name = case["mesh"]["name"]
    var_name = case["variable_name"]
    var_vals = case["var_values"]
    var_unit = unit_for_named_expr(var_name)

    if geom["dimension"] == 3:
        dim = pyfluent.Dimension.THREE
    else:
        dim = pyfluent.Dimension.TWO
    solver = pyfluent.launch_fluent(
        precision=pyfluent.Precision.DOUBLE,
        dimension=dim
    )
    # By default, it should use the available CPUs
    r = solver.settings
    t = solver.tui

    # General settings
    r.file.batch_options.confirm_overwrite = True
    r.file.batch_options.exit_on_error = True
    is_nat = "nat_conv" in str(geom.get("type", ""))
    if is_nat:
        print("The case is natural convection")

    # Mesh Loading
    r.file.read_case(file_name=f"{mesh_name}.msh.h5")

    # General Settings
    general = r.setup.general
    general.units.set_units(
        quantity="temperature", units_name="C", scale_factor=1.0, offset=273.15
    )

    # Operating Conditions
    if is_nat:
        # Gravity from orientation (XY is the plane in 2D)
        orient = geom.get("orientation")
        if orient == "horizontal":
            gravity_vec = [0.0, -9.81, 0.0]
        elif orient == "inverted":
            gravity_vec = [0.0, 9.81, 0.0]
        elif orient == "vertical":
            gravity_vec = [9.81, 0.0, 0.0]
        else:
            gravity_vec = [0.0, -9.81, 0.0]
            print(f"[WARN] Unknown orientation '{orient}', defaulting gravity={gravity_vec}")

        oc = general.operating_conditions
        oc.gravity.enable = True
        oc.gravity.components = gravity_vec

    # Reference Values
    ref_vals = r.setup.reference_values
    if ref_vals.depth.is_active():
        ref_vals.depth = geom["w_m"]

    # Named Expressions
    geom_named_expressions = prepare_named_expressions(geom)
    case_named_expressions = prepare_named_expressions(case)
    named_expressions = geom_named_expressions + case_named_expressions + default_named_expressions
    set_named_expressions(solver, case["evaporation"], named_expressions)

    # Load UDF
    print("=== Load UDF ===")
    # It looks like 2025 R1 does not have udf implemented so we use TUI
    # r.setup.user_defined.load(udf_library_name="libudf")
    if case["evaporation"]:
        udf = t.define.user_defined
        udf.user_defined_memory(1)
        udf.compiled_functions("load", "libudf")
        udf.function_hooks("initialization", "init_udm::libudf")

    print("=== Settings ===")

    # Models
    models = r.setup.models
    models.energy.enabled = True
    if case["evaporation"]:
        models.species.model.option = 'species-transport'
        models.species.options.diffusion_energy_source = True

    # Materials
    mat = r.setup.materials

    # Solids
    mat.solid["aluminum"].name = "base-material"
    mat.solid["base-material"].thermal_conductivity.option = "expression"
    # NOTE(review): "k_base_w_mk" is not listed in NAMED_EXPRESSIONS; presumably it is
    # provided by the default named expressions file - verify.
    mat.solid["base-material"].thermal_conductivity.expression = "k_base_w_mk"

    if case["evaporation"]:
        # No resistance material at interface
        mat.solid.create("no-resistance-material")
        mat.solid["no-resistance-material"].thermal_conductivity.option = "constant"
        mat.solid["no-resistance-material"].thermal_conductivity.value = 1e10

        # Hydrogel layer in the 3d case
        mat.solid.create("hydrogel")
        mat.solid["hydrogel"].thermal_conductivity.option = "constant"
        mat.solid["hydrogel"].thermal_conductivity.value = 0.6

        # Mixture Properties
        mat.mixture["mixture-template"].name = "moist-air"
        moist_air = mat.mixture["moist-air"]
        moist_air.species.volumetric = ["h2o", "air"]
        moist_air.density.option = "ideal-gas"
        moist_air.specific_heat.option = "mixing-law"
        moist_air.thermal_conductivity.option = "ideal-gas-mixing-law"
        moist_air.viscosity.option = "ideal-gas-mixing-law"
        if not case["cst_diff_coeff"]:
            moist_air.mass_diffusivity.option = "user-defined"
            moist_air.mass_diffusivity.user_defined_function = "D_va::libudf"
        else:
            moist_air.mass_diffusivity.option = "constant-dilute-appx"
        mat.fluid.delete("air")
        mat.fluid.delete("nitrogen")
        mat.fluid.delete("oxygen")

        # Air Properties
        air = moist_air.species.material["air"]
        set_air_properties(air, case["cst_air_props"])

        # Water Vapor Properties
        vapor = moist_air.species.material["water-vapor"]
        set_vapor_properties(vapor, case["cst_vapor_props"])

    else:
        # Dry Case
        air = mat.fluid["air"]
        air.density.option = "ideal-gas"
        set_air_properties(air, case["cst_air_props"])

    # Cell Zones Conditions
    general.operating_conditions.operating_pressure = "p_atm_pa"

    # Boundary Conditions
    bcs = r.setup.boundary_conditions
    if not is_nat:
        velocity_inlet = bcs.velocity_inlet["velocity-inlet"]
        velocity_inlet.momentum.velocity_magnitude.value = "u_in_m_s"
        velocity_inlet.thermal.temperature.value = "temp_in_degc"
    pressure_outlet = bcs.pressure_outlet["pressure-outlet"]
    pressure_outlet.thermal.backflow_total_temperature.value = "temp_in_degc"
    base = bcs.wall["base"]
    # The heating type is validated once here; it also selects the report that the
    # convergence condition monitors further below (the quantity that is NOT imposed).
    heating_type = case["heating_type"]
    if heating_type == "T_base":
        base.thermal.thermal_condition = "Temperature"
        base.thermal.temperature.value = "temp_base_degc"
        monitored_report = "q_base_w"
    elif heating_type == "q_base":
        base.thermal.thermal_condition = "Heat Flux"
        base.thermal.heat_flux.value = "q_base_w_m2"
        monitored_report = "temp_base_degc"
    else:
        raise ValueError(f'wrong heating_type: {heating_type}')

    if case["evaporation"]:
        if not is_nat:
            velocity_inlet.species.species_mass_fraction["h2o"].value = "Y_in"
        pressure_outlet.species.backflow_species_mass_fraction["h2o"].value = "Y_in"

        # Use whichever side of the interface wall exposes species settings.
        wall_interface = bcs.wall["wall-interface"]
        if not wall_interface.species():
            wall_interface = bcs.wall["wall-interface-shadow"]
        # Apply shell conduction for the 3D case with hydrogel thickness
        if geom["dimension"] == 3 and geom["t_hydrogel_m"] != 0:
            wall_interface.thermal.enable_shell_conduction = True
            wall_interface.thermal.thin_wall.resize(size=2)
            # First layer is the no-resistance material, starting from the fluid side
            wall_interface.thermal.thin_wall[0].thickness = geom["t_layer_m"]
            wall_interface.thermal.thin_wall[0].material = "no-resistance-material"
            wall_interface.thermal.thin_wall[0].qdot.option = "udf"
            # The alternative below always gives an error:
            # wall_interface.thermal.conduction_layers[0].qdot.option = 'udf'
            # To modify shell conduction model settings use define/models/shell-conduction/settings menu.
            # Error: GENERAL-CAR-CDR: invalid argument [1]: improper list
            # Error Object: ()
            wall_interface.thermal.thin_wall[0].qdot.udf = "wall_qdot_evap::libudf"
            wall_interface.thermal.thin_wall[1].thickness = geom["t_hydrogel_m"]
            wall_interface.thermal.thin_wall[1].material = "hydrogel"
        else:
            wall_interface.thermal.material = "no-resistance-material"
            wall_interface.thermal.wall_thickness.value = "t_layer_m"
            wall_interface.thermal.heat_generation_rate.option = "udf"
            wall_interface.thermal.heat_generation_rate.udf = "wall_qdot_evap::libudf"
        wall_interface.species.species_boundary_conditions["h2o"] = "Specified Mass Fraction"
        wall_interface.species.species_mass_fraction_or_flux["h2o"].value = "Y_sat"

    # Methods
    # Keep default methods for now

    # Report Definitions
    rd = r.solution.report_definitions
    set_report_definitions(rd, case["evaporation"], report_definitions)

    # Convergence Conditions
    res = r.solution.monitor.residual
    eqs = res.equations()
    for key, value in eqs.items():
        if key == "continuity":
            value["absolute_criteria"] = solver_params["cont_conv_crit"]
        elif key == "energy":
            value["absolute_criteria"] = solver_params["energy_conv_crit"]
        elif key in ["k", "omega", "epsilon"]:
            value["absolute_criteria"] = solver_params["turb_conv_crit"]
        elif key == "h2o":
            value["absolute_criteria"] = solver_params["h2o_conv_crit"]
        elif "velocity" in key:
            value["absolute_criteria"] = solver_params["vel_conv_crit"]
        else:
            raise ValueError(f'unsupported equation conv. criteria: {key}')
    conv_cond = r.solution.monitor.convergence_conditions
    conv_cond.frequency = solver_params["conv_freq"]
    base_res = conv_cond.convergence_reports.create("base_res")
    base_res.stop_criterion = solver_params["q_base_conv_crit"]
    base_res.print = True
    base_res.report_defs = monitored_report

    # Write Input Summary and Base Case
    print("=== Write Input Summary and Base Case ===")
    r.results.report.summary(write_to_file=True, file_name="report.sum")
    r.file.write_case(file_name=f'{case_name}_base.cas.h5')

    # # Reload the saved Case (workaround to turbulent viscosity ratio exploding)
    # r.file.read_case(file_name=f'{case_name}_base.cas.h5')

    # Initialize
    print("===== Initialize =====")
    init = r.solution.initialization
    init.standard_initialize()
    init.initialization_type = "hybrid"
    init.hybrid_init_options.general_settings.iter_count = solver_params["n_init"]
    init.hybrid_initialize()

    # Run
    run = r.solution.run_calculation
    ne = r.setup.named_expressions

    # Steps
    for idx, value in enumerate(var_vals, start=1):
        var_tag = str(value).replace('.', 'p')
        print(f"; ===== STEP {idx}: set {var_name} = {value} {var_unit} =====")

        # Change parameter. Numeric values get the unit appended; anything else
        # (e.g. an expression string) is passed through unchanged, matching
        # prepare_named_expressions, instead of failing on the 'g' format.
        if isinstance(value, (int, float)):
            ne[var_name].definition = f"{value:g} {var_unit}"
        else:
            ne[var_name].definition = str(value)

        # Solve
        prev_iteration = compute_report_def(rd, "iteration")
        run.iterate(iter_count=solver_params["n_iter"])

        # Write all output parameters and case data
        r.parameters.output_parameters.write_all_to_file(file_name=f"out/params_{var_name}={var_tag}.out",
                                                         append_data=False)
        # NOTE(review): write_data is given a ".cas.h5" file name - confirm the data
        # file ends up under the intended name.
        r.file.write_data(file_name=f"case_data/case_{var_name}={var_tag}.cas.h5")

        # Stop if the maximum number of iterations is reached: using the full
        # n_iter budget means the convergence conditions never stopped the run.
        iteration = compute_report_def(rd, "iteration")
        if (iteration - prev_iteration) >= solver_params["n_iter"]:
            raise RuntimeError(f"maximum number of iteration is reached for step: {var_name} = {value} {var_unit}")

    print("===== END =====")


if __name__ == "__main__":
    # Command-line entry point: --src_dir is the directory that holds the config/ folder.
    parser = argparse.ArgumentParser()
    parser.add_argument("--src_dir", required=True)
    cli_args = parser.parse_args()
    source_directory = Path(cli_args.src_dir)
    run_case(source_directory)
