In [2]:
from pyomo.environ import *
from pyomo.contrib.piecewise import PiecewiseLinearFunction
from pyomo.core.base import TransformationFactory
from pyomo.opt import SolverStatus, TerminationCondition
#from pyomo.kernel import *
#from pyomo.core.kernel.piecewise_library.transforms_nd import piecewise_nd
#from pyomo.core.kernel.piecewise_library import util as pwutil  # 方便自动生成剖分
from dataclasses import dataclass, field
from typing import Callable, List, Dict, Tuple
import numpy as np
import math
import matplotlib.pyplot as plt
import bisect
import itertools as it
#from pyomo.environ import ConcreteModel, Var, Constraint, Expression

In [26]:
def evaluate_Q_at(model, first_stg_vars, first_stg_vals, solver):
    """
    Evaluate the scenario value v(y) at a fixed first-stage point y.

    Fixes each variable in `first_stg_vars` to the matching entry of
    `first_stg_vals`, attaches a temporary objective `obj` minimizing
    `model.obj_expr`, solves, and returns value(model.obj_expr).
    The temporary objective is removed and the first-stage variables are
    unfixed before returning -- also when the solve fails -- so the model is
    left without `obj` and with the first-stage variables free.

    Args:
        model: Pyomo model exposing an `obj_expr` expression.
        first_stg_vars: sequence of first-stage Vars of `model`.
        first_stg_vals: sequence of values (numbers or Pyomo numeric objects),
            same length as `first_stg_vars`.
        solver: Pyomo solver object providing `solve(model, tee=...)`.

    Returns:
        float: value of `model.obj_expr` at the reported optimum.

    Raises:
        ValueError: if `first_stg_vars` and `first_stg_vals` differ in length.
        RuntimeError: if the solver does not report status ok and an optimal
            termination condition.
    """
    # zip() would silently drop the extra entries and leave variables free.
    if len(first_stg_vars) != len(first_stg_vals):
        raise ValueError(
            f"got {len(first_stg_vars)} first-stage variables but "
            f"{len(first_stg_vals)} values"
        )

    # Clear any remaining obj/As/pw components (left over from a previous round).
    del_components(model)

    try:
        for u, v in zip(first_stg_vars, first_stg_vals):
            u.fix(value(v))
        model.obj = Objective(expr=model.obj_expr, sense=minimize)
        results = solver.solve(model, tee=False)

        status_ok = (results.solver.status == SolverStatus.ok)
        term_ok = (results.solver.termination_condition == TerminationCondition.optimal)
        if not (status_ok and term_ok):
            raise RuntimeError(f"Scenario evaluate at y={first_stg_vals} not optimal: "
                               f"status={results.solver.status}, term={results.solver.termination_condition}")

        # Evaluated before the finally block runs; unfixing does not change values.
        return value(model.obj_expr)
    finally:
        # Always restore the model, even when the solve failed, so the next
        # caller does not inherit a stale objective or fixed first-stage vars.
        if hasattr(model, 'obj'):
            model.del_component('obj')
        for u in first_stg_vars:
            u.unfix()

def del_components(model):
    """Remove per-round leftover components from `model`.

    Deletes, in order, whichever of `obj`, `As`, `pw`, `pw_fun`, `pw_As` and
    `pw_link` are present. Presence is checked with `hasattr` immediately
    before each deletion. Names that are absent are skipped silently.
    """
    leftover_names = ('obj', 'As', 'pw', 'pw_fun', 'pw_As', 'pw_link')
    # filter() is lazy, so each hasattr check happens right before its deletion.
    for comp_name in filter(lambda nm: hasattr(model, nm), leftover_names):
        model.del_component(comp_name)

def corners_from_bounds(firt_stg_vars):
    """Enumerate every corner of the box spanned by the variables' bounds.

    Each variable contributes the choice (lb, ub), converted to float. The corners
    are the Cartesian product of these choices, in `itertools.product` order
    (the last variable varies fastest). For an empty input the result is `[()]`.

    Raises:
        ValueError: if any variable has no lower or no upper bound.
    """
    def _interval(var):
        low, high = var.lb, var.ub
        if low is None or high is None:
            raise ValueError(f"{var.name} 缺少上下界，无法生成角点")
        return (float(low), float(high))

    intervals = [_interval(var) for var in firt_stg_vars]
    return list(it.product(*intervals))


def add_nd_piecewise(
    model,
    firt_stg_vars,          # [x1, x2, ..., xN]  first-stage Vars of the model
    points,                 # [(c11,...,c1N), (c21,...,c2N), ...]  all nodes (same dimension N)
    values,                 # 1-D list/array aligned with points, or dict{point_tuple: value}
    name="pw",
    relation="==",          # '==': As == f(x); '>=': As >= f(x); '<=': As <= f(x)
    round_ndigits=12,       # coordinates are rounded to this many digits before table lookups
):
    """
    Attach an N-dimensional piecewise-linear interpolant of tabulated values to `model`.

    Builds a `PiecewiseLinearFunction` named `<name>_fun` through `points`. Its
    function values are looked up from `values`. The function also gets:
    - a scalar Var `<name>_As`;
    - a constraint `<name>_link` relating `<name>_As` to the interpolant
      evaluated at `firt_stg_vars` according to `relation`.
    Finally the 'contrib.piecewise.convex_combination' transformation is applied
    to the whole model.

    All inputs are validated before the model is modified. Leftover components
    are then removed:
    - everything `del_components` removes;
    - `<name>_fun`, `<name>_As` and `<name>_link` for this `name`.

    Args:
        model: Pyomo model to extend.
        firt_stg_vars: sequence of N first-stage Vars (argument order of the interpolant).
        points: non-empty sequence of N-tuples (interpolation nodes).
        values: sequence aligned with `points`, or a dict mapping point tuples to values.
        name: prefix for the created component names.
        relation: one of '==', '>=', '<='.
        round_ndigits: rounding applied to coordinates so lookups tolerate float noise.

    Returns:
        (As, pw): the created Var and the PiecewiseLinearFunction component.

    Raises:
        ValueError: empty `points`, a point of the wrong dimension, a `values`
            sequence of the wrong length, or an unknown `relation`.
        KeyError: a dict `values` missing some point.
    """
    # ---- validate everything before touching the model ----
    if len(points) == 0:
        raise ValueError("points 不能为空")
    N = len(firt_stg_vars)
    for pt in points:
        if len(pt) != N:
            raise ValueError(f"points 中出现与 x_vars 维度不一致的点: {pt}")

    # Tuple membership compares with ==, so unhashable inputs still reach the ValueError.
    if relation not in ("==", ">=", "<="):
        raise ValueError("relation 只能是 '==', '>=', '<='")

    # Normalize the float representation of coordinates to avoid lookup precision issues.
    def keyize(coords):
        return tuple(round(float(c), round_ndigits) for c in coords)

    norm_points = [keyize(pt) for pt in points]

    # Normalize `values` into a {point: value} table.
    if isinstance(values, dict):
        table = {keyize(k): float(v) for k, v in values.items()}
        # Every point must have a value.
        miss = [pt for pt in norm_points if pt not in table]
        if miss:
            raise KeyError(f"values 缺少这些点的取值: {miss[:5]}{' ...' if len(miss)>5 else ''}")
    else:
        # Treat as a 1-D sequence aligned with points.
        if len(values) != len(points):
            raise ValueError("values 长度应与 points 数量一致（或传 dict）")
        table = {pt: float(v) for pt, v in zip(norm_points, values)}

    # Table lookup; presumably only evaluated at the given points by Pyomo.
    def _f_from_table(*coords):
        return table[keyize(coords)]

    # ---- mutate the model ----
    del_components(model)
    # del_components only knows the default "pw" prefix; also clear this call's names
    # so repeated calls with a custom `name` do not collide in add_component.
    for leftover in (f"{name}_fun", f"{name}_As", f"{name}_link"):
        if hasattr(model, leftover):
            model.del_component(leftover)

    # Create the piecewise function and attach it to the model.
    pw = PiecewiseLinearFunction(points=norm_points, function=_f_from_table, name=f"{name}_fun")
    model.add_component(pw.name, pw)

    # Piecewise-linear expression in the first-stage variables.
    pw_expr = pw(*firt_stg_vars)

    As = Var(name=f"{name}_As")
    model.add_component(As.name, As)
    link_builders = {
        "==": lambda lhs, rhs: lhs == rhs,
        ">=": lambda lhs, rhs: lhs >= rhs,
        "<=": lambda lhs, rhs: lhs <= rhs,
    }
    link = Constraint(expr=link_builders[relation](As, pw_expr))
    model.add_component(f"{name}_link", link)

    # NOTE(review): the transformation is applied to the whole model, and whatever
    # components it adds are not among the names removed above -- confirm they do
    # not accumulate across repeated calls.
    TransformationFactory('contrib.piecewise.convex_combination').apply_to(model)

    return As, pw


def clone_and_get_vars(m_old, first_stage_vars):
    """
    Clone a model and map the given variables onto their counterparts in the clone.

    Args:
        m_old: model to clone (via its `clone()` method).
        first_stage_vars: iterable of variables of `m_old` (Var objects, not names).
            Each one is located in the clone by its `.name` using `find_component`.

    Returns:
        (m_new, vars_new): the clone and the matching variables, in input order.

    Raises:
        KeyError: if a variable's name cannot be found in the clone.
    """
    m_new = m_old.clone()

    def _counterpart(var):
        found = m_new.find_component(var.name)
        if found is None:
            raise KeyError(f"在新模型里找不到变量 '{var.name}'")
        return found

    return m_new, [_counterpart(var) for var in first_stage_vars]

# delete repeated nodes
def unique_points(points, atol=1e-9):
    """Drop near-duplicate points, keeping the first occurrence of each.

    A point is a duplicate of an already-kept point when every paired coordinate
    differs by at most `atol`. Coordinates are paired with `zip`, so extra
    trailing coordinates are ignored. Uses O(n^2) comparisons, which is fine
    for small node sets.
    """
    def _close(p, q):
        return all(abs(a - b) <= atol for a, b in zip(p, q))

    kept = []
    for candidate in points:
        if any(_close(candidate, seen) for seen in kept):
            continue
        kept.append(candidate)
    return kept


def _solve_sum_model(m_tmpl_list, first_stg_nodes, as_nodes_list, solver):
    """
    Build and solve the aggregated ("sum over scenarios") piecewise model.

    Clones the template model m_tmpl_list[0]. The first-stage Vars
    m_tmpl_list[1] are matched by name in the clone. The node-wise sum of the
    scenario values in `as_nodes_list` is interpolated through
    `first_stg_nodes`, and that interpolant is minimized.

    Returns:
        (results, model_sum, first_stg_vars, pw, assum_nodes):
        - the solver results;
        - the solved clone, returned so its components stay referenced;
        - the clone's first-stage Vars;
        - the PiecewiseLinearFunction component;
        - the summed node values.
    """
    # One summed value per node: column sums of the (scenario x node) table.
    assum_nodes = np.array(as_nodes_list, dtype=float, ndmin=2).sum(axis=0)

    model_sum, model_sum_first_stg_vars = clone_and_get_vars(m_tmpl_list[0], m_tmpl_list[1])
    del_components(model_sum)
    As, pw = add_nd_piecewise(
        model_sum, model_sum_first_stg_vars, first_stg_nodes, assum_nodes,
        name="pw", relation="=="
    )
    model_sum.obj = Objective(expr=As, sense=minimize)
    results = solver.solve(model_sum, tee=False)
    if not ((results.solver.status == SolverStatus.ok)
            and (results.solver.termination_condition == TerminationCondition.optimal)):
        print("Sum model doesn't get solved normally")
    return results, model_sum, model_sum_first_stg_vars, pw, assum_nodes


def _plot_underestimator(y_vals, q_vals, nodes, node_vals, ms, x_star, xlim, tag, title):
    """Plot a 1-D value curve against its node interpolant and the ms-shifted interpolant."""
    # Nodes are appended in discovery order; sort by x so the dashed lines are a proper PL graph.
    node_x = np.array([node[0] for node in nodes], dtype=float)
    order = np.argsort(node_x)
    y_nodes_arr = node_x[order]
    vals_arr = np.asarray(node_vals, dtype=float)[order]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(y_vals, q_vals, label=f'$Qs_{tag}$', color='red')
    ax.plot(y_nodes_arr, vals_arr + ms, label=f'$As_{tag} underest$', color='red', marker='o', linestyle='--', alpha=0.5)
    ax.plot(y_nodes_arr, vals_arr, label=f'$As_{tag}$', color='blue', marker='o', linestyle='--', alpha=0.5)
    ax.axvline(x=x_star, color='purple', linestyle='--')
    ax.set_xlim(*xlim)
    ax.set_xlabel('y')
    ax.set_ylabel('value')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations


def nc_underest(model_list, first_stg_vars_list, m_tmpl_list, target_nodes, picture_shown=False, v_list=False, tolerance=1e-8):
    """
    Adaptively add interpolation nodes over the first-stage box and return a lower
    bound on the sum of the scenario values.

    Starting from the corners of the first-stage box, each iteration does three things:
      1. For every scenario i, it interpolates the evaluated node values with an
         N-d piecewise-linear function As_i and minimizes obj_expr - As_i. The
         optimum ms_i is the gap between scenario i and its interpolant at the
         most negative point. The first-stage part of that optimum is scenario
         i's candidate node.
      2. It interpolates the node-wise sum of the scenario values on a clone of
         the template model and minimizes it. This gives the solver's lower
         bound As_min and a minimizer node_star. If node_star is unusable or
         already a node, the centroid of the current nodes is used instead.
         The interpolation error |As_min - sum_i Q_i(node_star)| is then measured.
      3. It adds node_star if that error exceeds |sum(ms)|; otherwise it adds
         the candidate of the scenario with the smallest ms. Every scenario is
         then evaluated at the new node.

    Parameters:
        model_list (list): one Pyomo model per scenario; each must expose `obj_expr`.
        first_stg_vars_list (list): per-scenario list of first-stage Vars, in the
            same order for every scenario. The bounds of the first list define the box.
        m_tmpl_list (list): [template model, list of that model's first-stage Vars];
            cloned to build the aggregated (sum) model.
        target_nodes (int): total number of nodes to reach; must exceed the number
            of box corners (2**d).
        picture_shown (bool): plot every iteration; only supported for a single
            first-stage variable.
        v_list (list): callables, one per scenario, giving the scenario value at a
            scalar y; required only when picture_shown is True.
        tolerance (float): currently unused; kept for interface compatibility.

    Returns:
        None if target_nodes is not larger than the number of corners. Otherwise
        the tuple (output_lb, first_stg_nodes, [k_list, As_min_list, add_node_history]):
        - output_lb (float): lower bound from minimizing the final sum interpolant,
          plus sum(ms) from the last iteration;
        - first_stg_nodes (list of tuples): all nodes, corners first;
        - k_list: node counts targeted per iteration;
        - As_min_list: As_min + sum(ms) per iteration;
        - add_node_history: the node added per iteration.

    Raises:
        ValueError: if picture_shown is set with more than one first-stage
            variable or without v_list.
    """
    N = len(model_list)
    as_nodes_list = [[] for _ in range(N)]  # as_nodes_list[i][j]: scenario i at first_stg_nodes[j]
    ms_list = [None] * N
    new_nodes_list = [None] * N             # candidate node proposed by each scenario
    As_min_list = []
    add_node_history = []
    k_list = []

    # set up solver
    solver = SolverFactory('gurobi')
    solver.options.update({
        'MIPGap': 1e-8,
        'MIPGapAbs': 0.0,
        'FeasibilityTol': 1e-9,
        'IntFeasTol':     1e-9,
        'OptimalityTol': 1e-9,
        'NumericFocus': 2,
        'ScaleFlag':    1,
        'Presolve': 2,
        'Method':  -1,
        'Crossover': -1,
        'NonConvex': 2,
    })

    # Start from the corner nodes; check the target before paying for any solves.
    first_stg_nodes = corners_from_bounds(first_stg_vars_list[0])
    if target_nodes <= len(first_stg_nodes):
        print('target_nodes number should be larger than ', len(first_stg_nodes))
        return

    ######### if we want to plot figures (1-D only) #########
    if picture_shown:
        if len(first_stg_vars_list[0]) != 1:
            raise ValueError("picture_shown=True is only supported for a single first-stage variable")
        if not v_list:
            raise ValueError("picture_shown=True requires v_list (one callable per scenario)")
        # corners_from_bounds has already verified that both bounds exist.
        plot_xlim = (float(first_stg_vars_list[0][0].lb), float(first_stg_vars_list[0][0].ub))
        y_vals = np.linspace(*plot_xlim, 100)
        Qs_vals_list = [[v_list[i](y) for y in y_vals] for i in range(N)]
        Qs_vals_sum = np.array(Qs_vals_list, dtype=float, ndmin=2).sum(axis=0)
    #########################################################

    for i in range(N):
        as_nodes_list[i].extend(
            evaluate_Q_at(model_list[i], first_stg_vars_list[i], node, solver) for node in first_stg_nodes
        )
    print('corner nodes are ', first_stg_nodes)
    print('as_nodes_list are ', as_nodes_list)

    print('Start from ', len(first_stg_nodes), ' nodes')
    print('The goal is to get ', target_nodes, ' nodes')
    for k in range(len(first_stg_nodes) + 1, target_nodes + 1):
        print('##################################################')
        print('##################################################')
        print('Start adding node ', k)
        k_list.append(k)

        # ---- step 1: per-scenario gap between value and interpolant ----
        for i in range(N):
            print(' ')
            print('Solving scenario ', i)
            del_components(model_list[i])
            As, _ = add_nd_piecewise(
                model_list[i], first_stg_vars_list[i], first_stg_nodes, as_nodes_list[i],
                name="pw", relation="=="
            )
            model_list[i].obj = Objective(expr=model_list[i].obj_expr - As, sense=minimize)
            results = solver.solve(model_list[i], tee=False)
            if (results.solver.status != SolverStatus.ok) or \
               (results.solver.termination_condition != TerminationCondition.optimal):
                print("⚠ There may be problems with the solution")

            ms_list[i] = value(model_list[i].obj)
            new_nodes_list[i] = tuple(value(v) for v in first_stg_vars_list[i])
            print('new node is ', new_nodes_list[i])
            print('ms is ', ms_list[i])

            if picture_shown:
                print(' ')
                print('The plot for scenario ', i)
                print('The potential y_star is ', new_nodes_list[i][0])
                print('ms is ', ms_list[i])
                _plot_underestimator(
                    y_vals, Qs_vals_list[i], first_stg_nodes, as_nodes_list[i], ms_list[i],
                    new_nodes_list[i][0], plot_xlim, str(i),
                    f"Plot for scenario {i} for {k} nodes",
                )

        # ---- step 2: minimize the summed interpolant and measure its error ----
        results, model_sum, model_sum_first_stg_vars, pw, assum_nodes = _solve_sum_model(
            m_tmpl_list, first_stg_nodes, as_nodes_list, solver
        )
        As_min = results.problem.lower_bound
        # exception=False yields None for unset vars instead of raising, so the fallback can trigger.
        node_star = tuple(value(v, exception=False) for v in model_sum_first_stg_vars)

        if any(c is None for c in node_star) or (node_star in first_stg_nodes):
            # Fall back to the centroid of the current nodes.
            node_star = tuple(sum(col) / len(col) for col in zip(*first_stg_nodes))
            As_min = value(pw(*node_star))
        q_sum_at_star = sum(
            evaluate_Q_at(model_list[i], first_stg_vars_list[i], node_star, solver) for i in range(N)
        )
        errors_node_star = abs(As_min - q_sum_at_star)

        sum_ms = sum(ms_list)

        if picture_shown:
            print(' ')
            print('The plot for As_sum')
            print('The potential y_star is ', node_star[0])
            print('error is ', errors_node_star)
            _plot_underestimator(
                y_vals, Qs_vals_sum, first_stg_nodes, assum_nodes, sum_ms,
                node_star[0], plot_xlim, 'sum',
                f"Plot for As_sum for {k} nodes",
            )

        # ---- step 3: choose and add the new node ----
        print('Sum *****************************************')
        print('error at y_star is ', errors_node_star)
        print('y_star is ', node_star)
        print('ms_list and sum_ms is ', ms_list, sum_ms)
        if errors_node_star > abs(sum_ms):
            new_node = node_star
            print('new node choosen from error')
        else:
            min_index = np.argmin(ms_list)
            # NOTE(review): this candidate may coincide with an existing node; no
            # duplicate check is done before it is added -- confirm this is acceptable.
            new_node = new_nodes_list[min_index]
            print('new node choosen from ms')
        As_min_list.append(As_min + sum_ms)
        add_node_history.append(new_node)
        print('new node is', new_node)
        print('Current As_min is', As_min_list[-1])
        print('*****************************************')
        print('')

        print('current nodes are ', first_stg_nodes)
        print('as_nodes_list are ', as_nodes_list)
        print('new_node is', new_node)
        first_stg_nodes.append(new_node)
        for i in range(N):
            as_nodes_list[i].append(evaluate_Q_at(model_list[i], first_stg_vars_list[i], new_node, solver))

    # Final aggregated solve over all nodes, including the last one added.
    results, model_sum, model_sum_first_stg_vars, _, _ = _solve_sum_model(
        m_tmpl_list, first_stg_nodes, as_nodes_list, solver
    )
    output_lb = results.problem.lower_bound + sum(ms_list)
    print('lower bound is ', output_lb)
    print('node is ', tuple(value(v) for v in model_sum_first_stg_vars))

    return output_lb, first_stg_nodes, [k_list, As_min_list, add_node_history]