In [None]:
using LinearAlgebra
using Distributions
using Optim
using Random
using StatsFuns
using JuMP
using MosekTools
using StatsBase
using SparseArrays 
using FileIO
using JLD2
using Plots
using LaTeXStrings
using DataFrames, Colors
using StatsPlots   

include("Params_PLD.jl")
include("Data_Generation_PLD.jl")
include("Estimation_PLD.jl")
# include("Estimation_PLD_Fast.jl")
include("Models_PLD.jl")
include("Evaluation_PLD.jl")
include("Implement_All_Methods_PLD.jl")
include("Figures_PLD.jl")

In [None]:
"""
    partition_interval(lb_input, ub_input, num_c::Integer)

Cut the interval `[lb_input, ub_input]` into `num_c` contiguous pieces of equal width.

Returns `(lowers, uppers)`, two vectors of length `num_c`, where piece `k` spans
`[lowers[k], uppers[k]]`.

Throws `ArgumentError` if `num_c < 1`.
"""
function partition_interval(lb_input, ub_input, num_c::Integer)
    num_c >= 1 || throw(ArgumentError("num_c must be >= 1"))
    # num_c + 1 evenly spaced breakpoints, both endpoints included.
    breakpoints = collect(range(lb_input, stop = ub_input, length = num_c + 1))
    piece_lows = breakpoints[1:end-1]
    piece_highs = breakpoints[2:end]
    return piece_lows, piece_highs
end

In [None]:
Params = get_default_params_PLD()
# Params = get_Wang_Qi_Shen_params_PLD()  # alternative parameter set

# Number of products and dimension of product features.
N, N_x = Params["N"], Params["N_x"]
c_l, d_r, rev_gap = Params["c_l"], Params["d_r"], Params["rev_gap"]
# Customer-feature dimension, test data size, maximum assortment size.
N_u, S_test, N_Max = Params["N_u"], Params["S_test"], Params["N_Max"]
# Nonzero entries in A; time limit for optimization.
N_nonzero, Time_Limit = Params["N_nonzero"], Params["Time_Limit"]
# Robust-optimization settings: dual norm, norm bounds, list of gamma values.
dual_norm, norm_bounds, gamma_list = Params["dual_norm"], Params["norm_bounds"], Params["gamma_list"]
# Bounds for psi and phi.
psi_lb, psi_ub = Params["psi_lb"], Params["psi_ub"]
phi_lb, phi_ub = Params["phi_lb"], Params["phi_ub"]
# `num_c` (number of customer classes) is set explicitly per run, not read from Params.
# Number of instances and random seed.
instances, seed = Params["instances"], Params["seed"]
coef_para_Input = Params["coef_this"] # coefficient bounds for data generation

In [None]:
S_train_list = Params["S_train_all"]   # training data sizes from Params
is_ridge = Params["is_ridge"]          # whether to use ridge regression

# Quick-testing overrides of the values above; delete these two lines for full runs.
S_train_list = [500]
instances = 20

In [None]:
"""
    Generate_Wang_Qi_Max_True_Paras_Multi_Class(N_x, N_u, N_nonzero, coef_Params, num_c)

Draw ground-truth utility parameters for `num_c` customer classes and one shared set of
revenue parameters.

The coefficient ranges in `coef_Params` (`alp0`, `alp`, `beta`, `A`) are each split into
`num_c` equal, contiguous sub-ranges with `partition_interval`. Class `c` is generated by
`Generate_Wang_Qi_Max_True_Paras` using the `c`-th sub-range of every coefficient. The
revenue bounds (`r0_*`, `r_*`) are passed through unchanged. The per-class revenue
parameters that helper returns are discarded.

Returns `(theta_true_all, r_params)`:
- `theta_true_all["class=c"]` holds the utility parameters of class `c`;
- `r_params = (r0 = ..., r = ...)` holds a scalar `r0 ~ Uniform(r0_lb, r0_ub)` and a
  vector `r` of `N_x` draws from `Uniform(r_lb, r_ub)`.
"""
function Generate_Wang_Qi_Max_True_Paras_Multi_Class(N_x,N_u,N_nonzero,coef_Params,num_c)
    alp0_lo, alp0_hi = partition_interval(coef_Params.alp0_lb, coef_Params.alp0_ub, num_c)
    alp_lo, alp_hi = partition_interval(coef_Params.alp_lb, coef_Params.alp_ub, num_c)
    beta_lo, beta_hi = partition_interval(coef_Params.beta_lb, coef_Params.beta_ub, num_c)
    A_lo, A_hi = partition_interval(coef_Params.A_lb, coef_Params.A_ub, num_c)

    theta_true_all = Dict()
    for c in 1:num_c
        class_bounds = (
            alp0_lb = alp0_lo[c], alp0_ub = alp0_hi[c],
            alp_lb = alp_lo[c], alp_ub = alp_hi[c],
            beta_lb = beta_lo[c], beta_ub = beta_hi[c],
            A_lb = A_lo[c], A_ub = A_hi[c],
            r0_lb = coef_Params.r0_lb, r0_ub = coef_Params.r0_ub,
            r_lb = coef_Params.r_lb, r_ub = coef_Params.r_ub,
        )
        # Only the utility parameters are kept. The call is still made once per class, in
        # class order, so any random draws it performs happen in the same sequence.
        theta_c, _ = Generate_Wang_Qi_Max_True_Paras(N_x, N_u, N_nonzero, class_bounds)
        theta_true_all["class=$(c)"] = theta_c
    end

    # Step 2: one set of revenue parameters shared by all classes (r0 is drawn before r).
    r_params = (r0 = rand(Uniform(coef_Params.r0_lb, coef_Params.r0_ub)),
                r = rand(Uniform(coef_Params.r_lb, coef_Params.r_ub), N_x))
    return theta_true_all, r_params
end

In [None]:
Random.seed!(seed)
is_Wang_Qi_Shen = true;
is_same_util_para = true;

# Output folder name encodes the data-generating setup, so runs with different
# settings write to different folders.
project_dir = if is_Wang_Qi_Shen
    "Model_Mis_Wang_Qi_Shen_N=$(N)_N_x=$(N_x)_N_u=$(N_u)_N_nonzero=$(N_nonzero)_dr=$(d_r[1])_seed=$(seed)"
else
    "Model_Mis_N=$(N)_N_x=$(N_x)_N_u=$(N_u)_N_nonzero=$(N_nonzero)_dr=$(d_r[1])_seed=$(seed)"
end
# In this cell the flag only selects the printed message and the folder suffix.
if is_same_util_para
    println("Generate data with the same utility parameters for all instances.")
    project_dir *= "_Same_Util_Para/"
else
    println("Generate data with different utility parameters for all instances.")
    project_dir *= "_Diff_Util_Para/"
end

# Results go to <great-grandparent of pwd()>/Data/Product_Line_Design/<project_dir>.
current_dir = pwd()
parent_dir = dirname(current_dir)
grand_pa_dir = dirname(parent_dir)
data_dir = string(dirname(grand_pa_dir), "/Data/Product_Line_Design/", project_dir)
mkpath(data_dir)  # creates missing folders; does nothing if data_dir already exists
println("Data directory: ", data_dir)
save(string(data_dir, "Params.jld2"), Params);

In [None]:
"""
    compute_w(params, z_input)

Evaluate the two utility components at customer features `z_input`:

- `nu0 = params.alpha0 + params.beta' * z_input` (scalar base term),
- `nu  = params.alpha .+ params.A * z_input` (vector term).

`params` must expose the properties `alpha0`, `alpha`, `beta` and `A`.
Returns the tuple `(nu0, nu)`.
"""
function compute_w(params, z_input)
    base_term = params.alpha0 + params.beta' * z_input
    vector_term = params.alpha .+ params.A * z_input
    return base_term, vector_term
end

In [None]:
"""
    output_results(S_train, lambda, data_dir, instances, fig_display, n_class = num_c)

Load the saved True/ETO/RO results of one `(S_train, lambda, n_class)` run from
`data_dir`, print objective and profit summaries, and return the per-instance profits.

# Arguments
- `S_train`, `lambda`: training size and regularization value used in the file names.
- `data_dir`: directory prefix holding the `.jld2` files. It is concatenated directly,
  so it should end with a path separator.
- `instances`: number of instances to aggregate.
- `fig_display`: currently unused (plotting is disabled). Kept for interface compatibility.
- `n_class`: number of customer classes used in the file names. It defaults to the global
  `num_c`, so existing five-argument calls behave exactly as before.

# Returns
`(profit_True, profit_ETO, profit_RO)` as produced by `obtain_profits`.
"""
function output_results(S_train, lambda, data_dir, instances, fig_display, n_class = num_c)
    # Same naming scheme as the files written by main_process.
    result_file(kind) = string(data_dir, "$(kind)_S=$(S_train)_lambda=$(lambda)_num_c=$(n_class).jld2")
    RST_True_All = load(result_file("RST_True"))
    RST_ETO_All = load(result_file("RST_ETO"))
    RST_RO_All = load(result_file("RST_RO"))

    # Recover the gamma grid from the RO result keys ("gamma=<value>"), ascending.
    gamma_list = sort([parse(Float64, split(k, "=")[end]) for k in keys(RST_RO_All["ins=1"])])

    obj_True, obj_ETO, obj_RO = obtain_obj(RST_True_All, RST_ETO_All, RST_RO_All, instances, gamma_list)
    println("S=$(S_train),lambda=$(lambda),obj True:",round.(mean(obj_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),obj ETO:",round.(mean(obj_ETO),digits=4))
    println("S=$(S_train),lambda=$(lambda),obj RO:",round.(mean(obj_RO,dims=1),digits=4))
    println()
    profit_True, profit_ETO, profit_RO = obtain_profits(RST_True_All, RST_ETO_All, RST_RO_All, instances, gamma_list)
    println("S=$(S_train),lambda=$(lambda),profit True:",round.(mean(profit_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),profit ETO/True:",round.(mean(profit_ETO)/mean(profit_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),profit RO/True:",round.(mean(profit_RO,dims=1)./mean(profit_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),profit RO/ETO:",round.(mean(profit_RO,dims=1)./mean(profit_ETO),digits=4))
    return profit_True, profit_ETO, profit_RO
end

In [None]:
"""
    plot_box(S_train, lambda, data_dir, chosen_indices, is_display,
             n_class = num_c, n_instances = instances)

Load the saved True/ETO/RO results of one `(S_train, lambda, n_class)` run, print
objective and profit summaries, and draw a box plot of per-instance profits. Each profit
is divided by the mean ETO profit.

# Arguments
- `S_train`, `lambda`: training size and regularization value used in the file names.
- `data_dir`: directory prefix holding the `.jld2` files. It should end with a path separator.
- `chosen_indices`: indices into the ascending gamma list that pick the RO variants to plot.
- `is_display`: forwarded to `boxplot_RPLD_vs_ETOPLD`.
- `n_class`: number of customer classes in the file names. Defaults to the global `num_c`,
  matching the `_num_c=` suffix that `main_process` writes.
- `n_instances`: number of instances to aggregate. Defaults to the global `instances`,
  which is what the five-argument form always used.

The figure is written to
`<data_dir>Boxplot_RPLD_vs_ETOPLD_S=<S_train>_lambda=<lambda>_num_c=<n_class>_Normalize.pdf`.
Returns whatever `boxplot_RPLD_vs_ETOPLD` returns.
"""
function plot_box(S_train, lambda, data_dir, chosen_indices, is_display,
                  n_class = num_c, n_instances = instances)
    # Same naming scheme as the files written by main_process (the old names lacked
    # the `_num_c=` suffix and therefore pointed at files that are never written).
    result_file(kind) = string(data_dir, "$(kind)_S=$(S_train)_lambda=$(lambda)_num_c=$(n_class).jld2")
    RST_True_All = load(result_file("RST_True"))
    RST_ETO_All = load(result_file("RST_ETO"))
    RST_RO_All = load(result_file("RST_RO"))

    # Recover the gamma grid from the RO result keys ("gamma=<value>"), ascending.
    gamma_list = sort([parse(Float64, split(k, "=")[end]) for k in keys(RST_RO_All["ins=1"])])

    obj_True, obj_ETO, obj_RO = obtain_obj(RST_True_All, RST_ETO_All, RST_RO_All, n_instances, gamma_list)
    println("S=$(S_train),lambda=$(lambda),obj True:",round.(mean(obj_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),obj ETO:",round.(mean(obj_ETO),digits=4))
    println("S=$(S_train),lambda=$(lambda),obj RO:",round.(mean(obj_RO,dims=1),digits=4))
    println()
    profit_True, profit_ETO, profit_RO = obtain_profits(RST_True_All, RST_ETO_All, RST_RO_All, n_instances, gamma_list)
    println("S=$(S_train),lambda=$(lambda),profit True:",round.(mean(profit_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),profit ETO/True:",round.(mean(profit_ETO)/mean(profit_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),profit RO/True:",round.(mean(profit_RO,dims=1)./mean(profit_True),digits=4))
    println("S=$(S_train),lambda=$(lambda),profit RO/ETO:",round.(mean(profit_RO,dims=1)./mean(profit_ETO),digits=4))

    # Normalize every method by the mean ETO profit so ETO is centered near 1.
    eto_scale = mean(profit_ETO)
    Profit_ETO_All_Ins = profit_ETO ./ eto_scale
    Profit_RO_All_Ins = Dict()
    for (g_index, gamma) in enumerate(gamma_list)
        Profit_RO_All_Ins["gamma=$(gamma)"] = profit_RO[:, g_index] ./ eto_scale
    end

    gamma_chosen = gamma_list[chosen_indices]
    data = [Profit_ETO_All_Ins, [Profit_RO_All_Ins["gamma=$(gamma)"] for gamma in gamma_chosen]...]
    labels = ["ETO"; ["RO($gamma)" for gamma in gamma_chosen]]
    # num_c is part of the figure name so plots for different class counts do not overwrite each other.
    fig_name = string(data_dir, "Boxplot_RPLD_vs_ETOPLD_S=$(S_train)_lambda=$(lambda)_num_c=$(n_class)_Normalize.pdf")
    return boxplot_RPLD_vs_ETOPLD(data, labels, fig_name, is_display)
end

In [None]:
"""
    Generate_Wang_Qi_Max_True_Data_Multi_Class(N_x, N_u, n_sample, m, theta_true, u_lb, u_ub)

Simulate `n_sample` MNL choice observations, each with an outside (no-purchase) option.

For each observation `i`:
- customer features `z_i` are `N_u` i.i.d. draws from `Uniform(u_lb, u_ub)`;
- the offered set has `m` products. Product `j` is a row of `N_x` features, each drawn
  uniformly from `{0.0, 1.0}`;
- product `j` has utility
  `U_ij = alpha0 + <alpha, x_ij> + <beta, z_i> + x_ij' * A * z_i`,
  using the fields of `theta_true`. The outside option has utility 0, i.e. weight
  `V0 = exp(0) = 1`. Choice probabilities are
  `P(j) = exp(U_ij) / (V0 + sum_l exp(U_il))` and `P(0) = V0 / (V0 + sum_l exp(U_il))`,
  and the chosen index is sampled from them.

# Returns
- `X::Vector{Matrix{Float64}}`: `X[i]` is the `m × N_x` product-design matrix of observation `i`.
- `Y::Vector{Int}`: chosen index in `0:m`, where `0` is the outside option and `j` is product `j`.
- `Z::Matrix{Float64}`: `n_sample × N_u`; row `i` is `z_i`.
"""
function Generate_Wang_Qi_Max_True_Data_Multi_Class(N_x, N_u, n_sample, m,theta_true,u_lb,u_ub)
    X = Vector{Matrix{Float64}}(undef, n_sample)
    Z = Matrix{Float64}(undef, n_sample, N_u)
    Y = Vector{Int}(undef, n_sample)

    for i in 1:n_sample
        z_i = rand(Uniform(u_lb, u_ub), N_u)
        Z[i, :] = z_i

        # Rows are drawn one at a time (row-major) to keep the random stream unchanged.
        X_i = zeros(m, N_x)
        for j in 1:m
            X_i[j, :] = rand(0.0:1.0, N_x)
        end
        X[i] = X_i

        # Terms that depend only on z_i are computed once per observation instead of once
        # per product. The left-to-right summation order below matches the per-product
        # formula, so the utilities are unchanged.
        beta_z = dot(theta_true.beta, z_i)
        A_z = theta_true.A * z_i
        utilities = zeros(m)
        for j in 1:m
            x_ij = X_i[j, :]
            utilities[j] = theta_true.alpha0 + dot(theta_true.alpha, x_ij) + beta_z + dot(x_ij, A_z)
        end

        # Numerically stable MNL probabilities. Exponentiating raw utilities overflows to
        # Inf (and Inf/Inf = NaN weights) once a utility exceeds ~709. Shifting by
        # max(0, utilities...) leaves the probabilities mathematically unchanged. It is
        # bit-identical to the unshifted form whenever all utilities are <= 0.
        shift = reduce(max, utilities; init = 0.0)
        exp_utilities = exp.(utilities .- shift)
        V0 = exp(-shift)  # outside-option weight exp(0), on the same shifted scale
        denominator = V0 + sum(exp_utilities)

        # Probability vector [P(outside option), P(product 1), ..., P(product m)].
        choice_probs = vcat(V0 / denominator, exp_utilities ./ denominator)
        # Outcome 0 = outside option, j = product j.
        Y[i] = sample(0:m, Weights(choice_probs))
    end
    return X, Y, Z
end

In [None]:
"""
    n_sample_per_class(S_train, class_probs, num_c)

Split `S_train` samples across `num_c` classes. Classes `1:num_c-1` each receive
`floor(S_train * class_probs[c])` samples. The last class receives the remainder, so the
counts always sum to `S_train`. Returns a `Vector{Int}`.
"""
function n_sample_per_class(S_train, class_probs, num_c)
    counts = zeros(num_c)
    counts[1:end-1] .= floor.(S_train .* class_probs[1:num_c-1])
    # The last class absorbs whatever the flooring left over.
    counts[end] = S_train - sum(view(counts, 1:num_c-1))
    return Int.(counts)
end

In [None]:
"""
    geo_probs(num_c::Integer, r::Real = 1.0)

Return a length-`num_c` probability vector whose entries form a geometric sequence with
common ratio `r` and sum to 1:

    p[i] = a * r^(i-1),   a = (1 - r) / (1 - r^num_c)   (r != 1)

A ratio within `1e-12` of 1 yields the uniform distribution `fill(1/num_c, num_c)`.
`r = 0` puts all mass on the first entry.

Throws `ArgumentError` if `num_c < 1`, or if `r` is negative or not finite. Such `r` would
produce negative or NaN entries, which are not valid probabilities.
"""
function geo_probs(num_c::Integer, r::Real=1.0)
    num_c >= 1 || throw(ArgumentError("num_c must be >= 1"))
    (isfinite(r) && r >= 0) || throw(ArgumentError("r must be finite and >= 0, got $(r)"))
    # Treat r numerically equal to 1 as exactly 1 (uniform distribution); the closed form
    # below would divide by ~0 there.
    if isapprox(r, 1.0; atol=1e-12)
        return fill(1.0/num_c, num_c)
    end
    # Leading term a for r != 1.
    a = (1.0 - r) / (1.0 - r^num_c)
    return [a * r^(i-1) for i in 1:num_c]
end

In [None]:
"""
    Generate_Data_this_Same_Para_Multi_Class(num_c, S_train, N_x, N_u, N_Max,
                                             theta_true_all_Fixed, r_params_Fixed)

Simulate one pooled training set over `num_c` customer classes that share the fixed
ground-truth parameters `theta_true_all_Fixed` / `r_params_Fixed`.

Classes have equal weight. `S_train` samples are split across classes with
`n_sample_per_class`, and every class draws customer features from `Uniform(-0.1, 0.1)`.
Class `c` is simulated with `theta_true_all_Fixed["class=c"]` and `N_Max` offered
products. The per-class samples are stacked in class order.

Returns a `Dict` with:
- the pooled data: `"X_train"`, `"Y_train"`, `"Z_train"`;
- the class setup: `"class_probs"`, `"n_samples"`, `"lowers"`, `"uppers"`;
- the fixed parameters: `"theta_true_all"`, `"r_params"`;
- `"asorrtment_train"`, in which every observation offers products `1:N_Max`.
"""
function Generate_Data_this_Same_Para_Multi_Class(num_c,S_train,N_x,N_u,N_Max,theta_true_all_Fixed, r_params_Fixed)
    # Equal class weights; `geo_probs(num_c, 0.5)` is an alternative skewed mixture.
    class_probs = fill(1.0 / num_c, num_c)
    n_samples = n_sample_per_class(S_train, class_probs, num_c)
    # Every class draws customer features from the same box [-0.1, 0.1].
    lowers = fill(-0.1, num_c)
    uppers = fill(0.1, num_c)

    # Simulate class by class, in class order (this fixes the random-number sequence).
    per_class = [
        Generate_Wang_Qi_Max_True_Data_Multi_Class(N_x, N_u, n_samples[c], N_Max,
            theta_true_all_Fixed["class=$(c)"], lowers[c], uppers[c])
        for c in 1:num_c
    ]
    X_train = vcat(map(first, per_class)...)
    Y_train = vcat(map(part -> part[2], per_class)...)
    Z_train = vcat(map(last, per_class)...)

    # Every training observation was offered the full product set 1:N_Max.
    asorrtment_train = Vector{Int64}[collect(1:N_Max) for _ in 1:S_train]

    Input_Data_this = Dict(
        "X_train" => X_train,
        "Y_train" => Y_train,
        "Z_train" => Z_train,
        "class_probs" => class_probs,
        "n_samples" => n_samples,
        "lowers" => lowers,
        "uppers" => uppers,
        "theta_true_all" => theta_true_all_Fixed,
        "r_params" => r_params_Fixed,
        "asorrtment_train" => asorrtment_train
    )
    return Input_Data_this
end

In [None]:
"""
    Get_Input_Data_Multi_Class(Input_Data_this)

Unpack a dataset `Dict` such as the one built by `Generate_Data_this_Same_Para_Multi_Class`.

Returns, in this order: `X_train`, `Y_train`, `Z_train`, `asorrtment_train`,
`class_probs`, `n_samples`, `lowers`, `uppers`, `theta_true_all`, `r_params`.
Throws `KeyError` if a required key is missing.
"""
function Get_Input_Data_Multi_Class(Input_Data_this)
    field(key) = Input_Data_this[key]
    X_train, Y_train, Z_train = field("X_train"), field("Y_train"), field("Z_train")
    class_probs, n_samples = field("class_probs"), field("n_samples")
    lowers, uppers = field("lowers"), field("uppers")
    theta_true_all, r_params = field("theta_true_all"), field("r_params")
    asorrtment_train = field("asorrtment_train")
    return (X_train, Y_train, Z_train, asorrtment_train, class_probs, n_samples,
            lowers, uppers, theta_true_all, r_params)
end

In [None]:
"""
    main_process(S_train_list, lambda, instances, num_c)

Run the simulation study for one `(lambda, num_c)` setting and save all results under
the global `data_dir`.

1. Draw one set of ground-truth utility parameters for `num_c` classes, plus revenue
   parameters, with `Generate_Wang_Qi_Max_True_Paras_Multi_Class`. Save them to
   `config_num_c=<num_c>_lambda=<lambda>.jld2`. They are reused for every `S_train` and
   every instance.
2. For each `S_train` in `S_train_list`, collect `instances` successful instances. One
   attempt does the following:
   - simulate training data;
   - estimate `theta_hat` with `Estimation_This`;
   - draw one test class (weighted by `class_probs`) and `S_test` test customers;
   - solve the True (oracle), ETO, and RO models, with RO solved for every gamma in
     `gamma_list`.
3. After each `S_train`, save the `Input_Data_`, `RST_True_`, `RST_ETO_` and `RST_RO_`
   files, each suffixed `S=<S_train>_lambda=<lambda>_num_c=<num_c>.jld2`.

An attempt is discarded and retried with the same `ins` when any of these holds:
- the estimated `nu` contains NaN;
- the True, ETO, or first-gamma RO model does not return `"OPTIMAL"`;
- the first-gamma RO objective differs from the ETO objective by more than 1e-3
  (relative).

Besides its arguments, this function reads these globals from the notebook: `N`, `N_x`,
`N_u`, `N_nonzero`, `N_Max`, `S_test`, `coef_para_Input`, `is_ridge`, `c_l`, `d_r`,
`rev_gap`, `Time_Limit`, `gamma_list`, `psi_lb`, `psi_ub`, `phi_lb`, `phi_ub`, `data_dir`.

NOTE(review): retries are not capped. If some failure condition holds on every attempt,
the `while` loop never terminates.
"""
function main_process(S_train_list,lambda,instances,num_c)
    # Ground truth is drawn once per call and shared by every S_train and instance.
    theta_true_all_Fixed, r_params_Fixed = Generate_Wang_Qi_Max_True_Paras_Multi_Class(N_x,N_u,N_nonzero,coef_para_Input,num_c);
    config = Dict()
    config["theta_true_all_Fixed"] = theta_true_all_Fixed
    config["r_params_Fixed"] = r_params_Fixed
    save(string(data_dir, "config_num_c=$(num_c)_lambda=$(lambda).jld2"), config);
    for S_train in S_train_list
        println("********** S_train = ",S_train," **********")
        # Per-instance results keyed "ins=<k>"; written to disk after all instances finish.
        Input_Data = Dict()
        RST_True_All = Dict()
        RST_ETO_All = Dict()
        RST_RO_All = Dict()
        ins = 1
        # `ins` only advances after a fully successful attempt. Each `continue` below retries
        # the same `ins` with fresh random data; entries already stored under "ins=<ins>" in
        # an earlier failed attempt are overwritten by the retry.
        while ins <= instances
            # ******** Data generation *************
            Input_Data_this = Generate_Data_this_Same_Para_Multi_Class(num_c,S_train,N_x,N_u,N_Max,theta_true_all_Fixed, r_params_Fixed);
            X_train, Y_train, Z_train, asorrtment_train, class_probs, n_samples, lowers, uppers, theta_true_all, r_params = Get_Input_Data_Multi_Class(Input_Data_this);
            # ******** Estimation *************
            theta_hat = Estimation_This(N_Max,N_x,N_u,Y_train,X_train,Z_train, asorrtment_train,is_ridge, lambda)
            
            #******** Generate test data *************
            # The whole test population comes from a single class, drawn with class_probs.
            class_chosen = sample(1:num_c, Weights(class_probs))
            u_lb = lowers[class_chosen]
            u_ub = uppers[class_chosen]
            theta_true = theta_true_all["class=$(class_chosen)"]
            # Only Z_test (customer features) is used below; X_t and Y_t are not used.
            X_t,Y_t,Z_test = Generate_Wang_Qi_Max_True_Data_Multi_Class(N_x, N_u, S_test, N_Max,theta_true,u_lb,u_ub)

            # True and estimated [nu0; nu], evaluated at the first test customer only.
            nu0_true,nu_true = compute_w(theta_true,Z_test[1,:])  
            nu_all_true = [nu0_true;nu_true]

            nu0_hat,nu_hat = compute_w(theta_hat,Z_test[1,:])  
            nu_all_hat = [nu0_hat;nu_hat]
            
            if any(isnan, nu_all_hat)
                println("Estimate contains NaN values.")
                continue
            end
            # if norm(vec(nu_all_true .- nu_all_hat),2) >= norm_bounds
            #     println("Estimate is too far from true parameters.")
            #     continue
            # end

            Input_Data_this["theta_hat"] = theta_hat
            Input_Data_this["nu_true"] = nu_all_true
            Input_Data_this["nu_hat"] = nu_all_hat
            Input_Data_this["class_chosen"] = class_chosen
            Input_Data_this["Z_test"] = Z_test
            Input_Data["ins=$(ins)"] = Input_Data_this

            # ******** True Model *************
            # Oracle benchmark: the decision model is given the true parameters.
            theta_Input = theta_true
            RST_True,status_True = solve_ETO_This(S_test,N,N_x,theta_Input,theta_true,r_params,c_l,d_r,rev_gap,num_c,Time_Limit,Z_test)
            # println("Oracle: status = ",status_True,",obj=",RST_True["obj"][1])
            if status_True != "OPTIMAL"
                println("Warning: The true model did not reach optimality")
                continue
            end
            RST_True_All["ins=$(ins)"] = RST_True

            # ******** ETO Model *************
            # Estimate-then-optimize: the decision model is given theta_hat.
            RST_ETO,status_ETO = solve_ETO_This(S_test,N,N_x,theta_hat,theta_true,r_params,c_l,d_r,rev_gap,num_c,Time_Limit,Z_test)
            # println("ETO: status = ",status_ETO,",obj=",RST_ETO["obj"][1])
            if status_ETO != "OPTIMAL"
                println("Warning: The ETO model did not reach optimality")
                continue
            end
            RST_ETO_All["ins=$(ins)"] = RST_ETO
            
            # ******** RO Model *************
            RST_RO_this = Dict()
            # The first gamma serves as a sanity check: its RO objective must match ETO's.
            # NOTE(review): presumes gamma_list[1] makes RO coincide with ETO (e.g. gamma = 0) — confirm.
            gamma = gamma_list[1]
            RST_RO,status_RO = solve_RO_this(S_test,N,N_x,theta_hat,theta_true,r_params,c_l,d_r,rev_gap,num_c,Time_Limit,Z_test,gamma,psi_lb,psi_ub,phi_lb,phi_ub)
            # println("gamma = $gamma, RO: status = ",status_RO,",obj=",RST_RO["obj"][1])
            if status_RO != "OPTIMAL"
                println("Warning: The RO model did not reach optimality")
                continue
            end
            # Relative gap between the first-gamma RO and ETO objectives. If both are 0
            # this is NaN, `NaN > 1e-3` is false, and the attempt is accepted.
            ratio = abs(RST_RO["obj"][1] - RST_ETO["obj"][1])/abs(RST_ETO["obj"][1])
            # ratio = abs(RST_RO["profit"][1] - RST_ETO["profit"][1])/abs(RST_ETO["profit"][1])
            if ratio > 1e-3
                println("Warning: The RO obj is not equivalent to ETO obj: ETO_Obj=",RST_ETO["obj"][1],",RO_Obj=",RST_RO["obj"][1])
                continue
            end
            RST_RO_this[string("gamma=",gamma)] = RST_RO

            # Remaining gammas. NOTE(review): status_RO is not checked here, so
            # non-optimal RO solutions are stored as-is.
            for g_index in 2:length(gamma_list)
                gamma = gamma_list[g_index]
                RST_RO,status_RO = solve_RO_this(S_test,N,N_x,theta_hat,theta_true,r_params,c_l,d_r,rev_gap,num_c,Time_Limit,Z_test,gamma,psi_lb,psi_ub,phi_lb,phi_ub)
                # println("gamma = $gamma, RO: status = ",status_RO,",obj=",RST_RO["obj"][1])
                RST_RO_this[string("gamma=",gamma)] = RST_RO
            end
            RST_RO_All["ins=$(ins)"] = RST_RO_this

            println("******* ins = ",ins,"*********")
            ins = ins + 1
        end
        save(string(data_dir, "Input_Data_S=$(S_train)_lambda=$(lambda)_num_c=$(num_c).jld2"), Input_Data);
        save(string(data_dir, "RST_True_S=$(S_train)_lambda=$(lambda)_num_c=$(num_c).jld2"), RST_True_All);
        save(string(data_dir, "RST_ETO_S=$(S_train)_lambda=$(lambda)_num_c=$(num_c).jld2"), RST_ETO_All);
        save(string(data_dir, "RST_RO_S=$(S_train)_lambda=$(lambda)_num_c=$(num_c).jld2"), RST_RO_All);
    end
end

### Lambda = 0.001

In [None]:
# Ridge penalty and number of customer classes for this run.
lambda = 0.001
num_c = 4
main_process(S_train_list, lambda, instances, num_c)

In [None]:
S_train = S_train_list[1]
RST_True_All = load(string(data_dir, "RST_True_S=$(S_train)_lambda=$(lambda)_num_c=$(num_c).jld2"));
RST_ETO_All = load(string(data_dir, "RST_ETO_S=$(S_train)_lambda=$(lambda)_num_c=$(num_c).jld2"));

# Mean test profit of each instance, for the oracle (True) and ETO decisions.
profit_True = zeros(instances);
profit_ETO = zeros(instances);
for k in eachindex(profit_True)
    profit_True[k] = mean(RST_True_All["ins=$(k)"]["profit"])
    profit_ETO[k] = mean(RST_ETO_All["ins=$(k)"]["profit"])
end
println("profit True:",round.(mean(profit_True),digits=4))
println("profit ETO/True:",round.(mean(profit_ETO)/mean(profit_True),digits=4))

In [None]:
# Print the full True/ETO/RO summary for this run and keep the per-instance profits.
profit_True, profit_ETO, profit_RO = output_results(S_train, lambda, data_dir, instances, true);

In [None]:
# lambda = 0.001
# num_c = 8
# main_process(S_train_list,lambda,instances,num_c) 

In [None]:
# profit_True, profit_ETO, profit_RO = output_results(S_train,lambda,data_dir,instances,true);

In [None]:
# lambda = 0.001
# num_c = 16
# main_process(S_train_list,lambda,instances,num_c) 

In [None]:
# lambda = 0.001
# num_c = 12
# main_process(S_train_list,lambda,instances,num_c)