In [None]:
using PyCall
# Python handle to SageMath's `sage.all` namespace; the verification code below
# builds its `MixedIntegerLinearProgram` objects from it.
sage_all = PyCall.pyimport("sage.all")

In [None]:
using PyCall
# This cell duplicates the previous one. Only import `sage.all` when it is not bound yet,
# so the cell stays usable on its own without repeating the work.
@isdefined(sage_all) || (sage_all = pyimport("sage.all"))

In [None]:
using Random
using IntervalArithmetic
using Gurobi
using JuMP
using MIPVerify
using MIPVerify:relu
using MIPVerify:NNetFormat,evaluate_network,evaluate_network_withoutNorm,evaluate_network_multiple,info
using MIPVerify:prep_data_file,get_matrix_params
using MIPVerify:JuMPLinearType
using MAT
using Images
using Printf

In [None]:
"""
    VerifyInfo

Outcome of the robustness check for one MNIST test sample.

Fields:
- `index`: index of the sample in the MNIST test set (as passed to `MIPVerify.get_image`).
- `time`: seconds spent on the sample.
- `adv_found`: whether a counterexample was found; `true` means a genuine counterexample,
  `false` a spurious one.
- `adv_not_found`: `true` when the solver could not produce a solution, i.e. no
  counterexample was found and the sample is treated as robust; `false` when the solver
  returned a solution.
- `verified`: `true` if the sample received a verdict, `false` otherwise.
- `verified_res`: `1` robust, `-1` not robust, `0` unknown.
"""
struct VerifyInfo
    index::Int
    time::Float64
    adv_found::Bool
    adv_not_found::Bool
    verified::Bool
    verified_res::Int
end

In [None]:
# Fmipverify: search for counterexamples directly in the 784-dimensional input space

In [None]:
"""
    robust_verify_mnist784(; num_samples=5513, epsilon=0.15)

Run the Fmipverify robustness check on the first `num_samples` MNIST test images.
Counterexamples are searched for directly in the 784-dimensional input space.

# Procedure
For every image the concrete `F16MNIST24` network classifies correctly:
1. Load the interval-encoded network `F16MNIST24itv` into a Sage
   `MixedIntegerLinearProgram` (Gurobi backend), with an L∞ perturbation family of
   radius `epsilon` around the image.
2. Ask the solver for an input where some class other than the prediction attains the
   maximum.
3. Replay the solver's input through the concrete Float16 network. The resulting record
   depends on the outcome:
   - The replay still predicts the original class: spurious counterexample,
     `VerifyInfo(index, t, false, false, false, 0)`.
   - The replay predicts another class: genuine counterexample,
     `VerifyInfo(index, t, true, false, true, -1)`.
   - Any exception other than `InterruptException` while processing the sample
     (presumably the solver reporting the MILP as infeasible): treated as robust,
     `VerifyInfo(index, t, false, true, true, 1)`.

Misclassified images are skipped and only counted.

# Returns
A `Dict` with the following entries:
- `:result`: `Vector{VerifyInfo}`, one entry per correctly classified sample.
- `:misclassified`: the number of skipped, misclassified samples.
- `:round_error_class`: a `Vector{Int}` that is currently never filled (always empty).
"""
function robust_verify_mnist784(; num_samples::Integer=5513, epsilon::Real=0.15)
    misclassified = 0
    result = VerifyInfo[]
    round_error_class = Int[]  # kept for the shape of the returned Dict; nothing pushes to it
    mnist = MIPVerify.read_datasets("MNIST")
    nn784 = MIPVerify.get_example_network_params("F16MNIST24")
    input_size = 784
    flat_layer = Flatten(4)  # loop-invariant, so build it once

    for index in 1:num_samples
        # Measure from here so the recorded time is correct in every branch. The previous
        # `total_time = @elapsed begin ... end` read `total_time` inside the block,
        # before it was assigned, so every record got 0.
        t_start = time()
        try
            ## Load the sample; skip it if the concrete network already misclassifies it.
            sample_image = MIPVerify.get_image(mnist.test.images, index)
            sample_label = MIPVerify.get_label(mnist.test.labels, index)
            predicted_output = Float16.(sample_image) |> nn784
            predicted_output_index = argmax(predicted_output)
            if predicted_output_index != sample_label + 1  # labels compared 0-based vs argmax 1-based
                misclassified += 1
                continue
            end

            flat_img = flat_layer(sample_image)
            @assert(
                length(flat_img) == input_size,
                "the length of flat img: $(length(flat_img)) is not equal to the input size: $input_size"
            )

            ## Build the MILP: interval-encoded network plus an L∞ ball around the image.
            sgSolver = sage_all.MixedIntegerLinearProgram(solver="gurobi")
            model = Model(Gurobi.Optimizer)
            set_optimizer_attribute(model, "OutputFlag", 0)
            nn_itv = MIPVerify.get_example_network_params_withSageSolver("F16MNIST24itv", sgSolver)
            d1 = MIPVerify.set_robustradius_checkrobust_withsage(
                sgSolver,
                model,
                nn_itv,
                flat_img,
                MIPVerify.LInfNormBoundedPerturbationFamily(epsilon),
                0.1,  # NOTE(review): meaning of this positional argument is not visible here
            )

            ## Bound each Sage output variable by the interval encoding of the network outputs.
            output_jump = d1[:Output]
            output_itv = d1[:Output_itv]
            output_var = sgSolver.new_variable()
            for i in 1:size(output_itv, 1)
                sgSolver.add_constraint(output_var[i] <= output_itv[i].var_up)
                sgSolver.add_constraint(output_var[i] >= output_itv[i].var_lo)
            end

            ## Require some class other than the prediction to attain the maximum output.
            num_classes = length(predicted_output)
            target_index = setdiff(1:num_classes, predicted_output_index)
            output_var_list = [output_var[i] for i in 1:num_classes]
            MIPVerify.set_max_indexes_withsage(
                sgSolver, output_jump, output_var_list, target_index;
                margin=floatmin(Float64),
            )
            sgSolver.solve()

            ## Replay the solver's candidate input through the concrete Float16 network.
            solved_input = [
                convert(Float16, sgSolver.get_values(d1[:PerturbedInput_itv][i]))
                for i in 1:input_size
            ]
            perturbation_predicted_index = argmax(solved_input |> nn784)
            elapsed = time() - t_start
            if perturbation_predicted_index == predicted_output_index
                # Spurious counterexample: the concrete network still predicts the label.
                push!(result, VerifyInfo(index, elapsed, false, false, false, 0))
            else
                # Genuine counterexample.
                push!(result, VerifyInfo(index, elapsed, true, false, true, -1))
            end
        catch e
            # Never swallow a user interrupt; otherwise Ctrl-C would be recorded as "robust".
            e isa InterruptException && rethrow()
            # NOTE(review): every other exception is counted as robust, on the assumption that
            # it is the solver reporting infeasibility. Narrow this to the solver's
            # infeasibility exception once its type is confirmed.
            @debug "Sample $index raised an exception; recording it as robust" exception = (e, catch_backtrace())
            push!(result, VerifyInfo(index, time() - t_start, false, true, true, 1))
        end
    end

    run_statistics = Dict(
        :result => result,
        :misclassified => misclassified,
        :round_error_class => round_error_class,
    )
    return run_statistics
end


In [None]:
# Run the full verification sweep over the MNIST test images. The returned Dict holds
# :result (Vector{VerifyInfo}), :misclassified and :round_error_class.
return784 = robust_verify_mnist784()

In [None]:
# Summarise the sweep from the previous cell.
# `res784` was referenced here but never defined (UndefVarError); bind it explicitly.
res784 = return784[:result]

statistics_dict = Dict{String, Int}()
# Every correctly classified sample contributes exactly one entry to `res784`, so this
# equals `5513 - misclassified` for the default run, and stays correct for other sizes.
statistics_dict["total_sample"] = length(res784)

# Samples verified robust (verified_res == 1).
statistics_dict["verified_t"] = count(x -> x.verified && x.verified_res == 1, res784)

# Samples with a genuine counterexample (verified_res == -1, i.e. definitively not robust).
statistics_dict["verified_f"] = count(x -> x.verified && x.verified_res == -1, res784)

# Remaining samples have no verdict (spurious counterexample, verified_res == 0).
# Genuine counterexamples are no longer lumped in with these.
statistics_dict["verified_unknow"] =
    statistics_dict["total_sample"] - statistics_dict["verified_t"] - statistics_dict["verified_f"]

println(statistics_dict)
