In [None]:
using MIPVerify
using Gurobi
using JuMP
using Images
using Printf
"""
    print_summary(d::Dict)

Print the objective value and the solve time (formatted with two decimals) of the JuMP
model stored under `d[:Model]`. Returns `nothing`.
"""
function print_summary(d::Dict)
    model = d[:Model]
    objective = JuMP.objective_value(model)
    seconds = @sprintf("%.2f", JuMP.solve_time(model))
    return println("Objective Value: ", objective, ", Solve Time: ", seconds)
end

"""
    view_diff(diff::AbstractMatrix{<:Real}; n::Integer = 1001)

Render a perturbation matrix as an image by mapping each entry to a color from the
`n`-step diverging `"RdBu"` colormap.

Values are expected in [-1, 1]: -1 maps to the first color and 1 to the last. Values
outside that range (and exactly -1, which previously produced index 0) are clamped to the
ends of the colormap instead of raising a `BoundsError`.

# Returns
A matrix of colors with the same shape as `diff`.
"""
function view_diff(diff::AbstractMatrix{<:Real}; n::Integer = 1001)
    cmap = colormap("RdBu", n)
    return cmap[_diff_to_colormap_index.(diff, n)]
end

# Map a value in [-1, 1] to a 1-based index into an `n`-entry colormap, clamped to 1:n.
function _diff_to_colormap_index(x::Real, n::Integer)
    # Same arithmetic as the original mapping; x == -1 yields 0, which clamp lifts to 1.
    return clamp(ceil(Int, (x + 1) / 2 * n), 1, n)
end

In [None]:
# colorview(Gray, sample_image[1, :, :, 1]*10)
# Gray.(images[index]*10)  # Display the image
# sample_image = Float16.(sample_image)
# predicted_output = sample_image |> n1
# predicted_index = predicted_output |> MIPVerify.get_max_index
# label_index = [1,2,3,4,5,6,7,8,9,10]    
# target_index = [i for i in label_index if i!=predicted_index]

In [None]:
"""
    VerifyInfo(index, time, adv_found, adv_not_found, verified, verified_res)

Outcome record for one test sample, as pushed by `robust_verify_mnist77`.

Flag combinations written in this file:

| outcome                                         | adv_found | adv_not_found | verified | verified_res |
|-------------------------------------------------|-----------|---------------|----------|--------------|
| counterexample still misclassified in Float16   | true      | false         | true     | -1           |
| spurious counterexample (prediction unchanged)  | false     | false         | false    | 0            |
| verification raised an exception (taken as robust) | false  | true          | true     | 1            |
"""
struct VerifyInfo
    index::Int           # 1-based position of the sample in the loaded test set
    time::Float64        # seconds recorded for the verification step
    adv_found::Bool      # a real adversarial example was found
    adv_not_found::Bool  # no adversarial example was reported for this sample
    verified::Bool       # the sample's status was decided (robust or not robust)
    verified_res::Int    # 1 = robust, -1 = not robust, 0 = undecided
end

In [None]:
"""
    robust_verify_mnist77(; num_samples=5622, radius=0.05, epsilon=0.15,
                          binary_file_path="../../../resized_images/resized_mnist_images77_test.bin",
                          num_images=10000)

Run MIPVerify robustness checking on the first `num_samples` 7×7 MNIST test images using
the Float16 example network `"F16MNISTinput_77"`.

Samples that the network misclassifies are only counted. For each correctly classified
sample, `MIPVerify.robustness_checking_verification` is run against every other class.
The solver is Gurobi with output suppressed, using `lp` tightening and an L∞ perturbation
family bounded by `epsilon` at the given `radius`. One `VerifyInfo` is recorded per sample:

- the returned counterexample is still misclassified in Float16 → not robust (`verified_res = -1`)
- the counterexample no longer changes the Float16 prediction → spurious, undecided (`0`)
- the verification call raised an exception → recorded as robust (`1`).
  NOTE(review): this assumes the verifier only throws when no counterexample exists within
  `radius`; any other solver/JuMP error inside that call is also counted as robust.

# Returns
A `Dict` with keys `:result` (`Vector{VerifyInfo}`), `:misclassified` (`Int`) and
`:round_error_class` (`Vector{Int}`; currently never populated).

# Throws
- `ArgumentError` if `num_samples > num_images`.
- Errors while loading or classifying a sample propagate; previously they were silently
  counted as "verified robust".
- `InterruptException` is never swallowed, so the loop can be stopped.
"""
function robust_verify_mnist77(;
    num_samples::Integer = 5622,
    radius::Real = 0.05,
    epsilon::Real = 0.15,
    binary_file_path::AbstractString = "../../../resized_images/resized_mnist_images77_test.bin",
    num_images::Integer = 10000,
)
    num_samples <= num_images || throw(ArgumentError(
        "num_samples ($num_samples) must not exceed num_images ($num_images)"))

    misclassified = 0
    result = VerifyInfo[]
    round_error_class = Int[]  # kept in the returned Dict for compatibility; never filled
    image_size = (7, 7)
    images, labels = load_all_binary_images(binary_file_path, image_size, num_images)
    n1 = MIPVerify.get_example_network_params("F16MNISTinput_77")
    label_index = 1:10  # the network's 1-based output classes

    for index in 1:num_samples
        # Reshape to (1, 7, 7, 1) and evaluate in Float16, matching the network's precision.
        sample_image = Float16.(reshape(images[index], (1, image_size..., 1)))
        sample_label = labels[index]
        predicted_output_index = argmax(sample_image |> n1)

        # Labels are 0-based digits while network output indices are 1-based.
        if predicted_output_index != sample_label + 1
            misclassified += 1
            continue
        end

        target_index = [i for i in label_index if i != predicted_output_index]
        push!(result, _verify_sample(n1, sample_image, index, predicted_output_index,
                                     target_index, radius, epsilon))
    end

    return Dict(
        :result => result,
        :misclassified => misclassified,
        :round_error_class => round_error_class,
    )
end

# Run the MIP verification for one correctly classified sample and return its VerifyInfo.
function _verify_sample(n1, sample_image, index, predicted_output_index, target_index,
                        radius, epsilon)
    # Timed from a start stamp so the exception branch also records the elapsed time
    # (with `@elapsed`, a throw inside the timed block left the time at 0).
    start_ns = time_ns()
    try
        d = MIPVerify.robustness_checking_verification(
            n1,
            sample_image,
            target_index,
            Gurobi.Optimizer,
            Dict("OutputFlag" => 0),  # suppress Gurobi console output
            pp = MIPVerify.LInfNormBoundedPerturbationFamily(epsilon),
            norm_order = Inf,
            tightening_algorithm = lp,
            radius = radius,
        )
        total_time = (time_ns() - start_ns) / 1e9

        # Re-evaluate the returned perturbed input in Float16: if the prediction does not
        # change, the solver's counterexample is spurious rather than a real adversarial.
        perturbed_sample_image = Float16.(value.(d[:PerturbedInput]))
        perturbation_predicted_index = argmax(perturbed_sample_image |> n1)
        if perturbation_predicted_index == predicted_output_index
            return VerifyInfo(index, total_time, false, false, false, 0)
        end
        return VerifyInfo(index, total_time, true, false, true, -1)
    catch e
        e isa InterruptException && rethrow()
        @debug "Verification of sample $index raised; recording it as verified robust" exception = (e, catch_backtrace())
        return VerifyInfo(index, (time_ns() - start_ns) / 1e9, false, true, true, 1)
    end
end


In [None]:
# Run the full verification sweep; the returned Dict holds :result, :misclassified and
# :round_error_class.
return77 = robust_verify_mnist77()

In [None]:
# Per-sample VerifyInfo records produced by the sweep.
res77 = return77[:result]

In [None]:
# Summarize the per-sample verification outcomes stored in `res77`.
statistics_dict = Dict{String, Int}()

# Samples whose solver counterexample is still misclassified when re-evaluated in Float16.
tp_adv = count(x -> x.adv_found, res77)
statistics_dict["tp_adv"] = tp_adv

# Every sample without a confirmed adversarial example.
# NOTE(review): this includes the samples counted in "verified_t" as well as the spurious
# counterexamples in "verified_unknow"; confirm that this is the intended definition.
tf_adv = count(x -> !x.adv_found, res77)
statistics_dict["tf_adv"] = tf_adv

# Samples recorded as verified robust.
verified_t = count(x -> x.verified && x.verified_res == 1, res77)
statistics_dict["verified_t"] = verified_t

# Samples recorded as verified not robust (real adversarial example found).
verified_f = count(x -> x.verified && x.verified_res == -1, res77)
statistics_dict["verified_f"] = verified_f

# Samples left undecided because the solver's counterexample was spurious in Float16.
verified_unknow = count(x -> !x.verified && x.verified_res == 0, res77)
statistics_dict["verified_unknow"] = verified_unknow

# Every correctly classified sample contributes exactly one record, so this equals
# (number of samples checked) - misclassified without hard-coding the loop bound.
statistics_dict["total_sample"] = length(res77)

# Print the dictionary.
println(statistics_dict)


In [None]:
# Show the collected statistics as this cell's output.
statistics_dict

In [None]:
# Indices where MIPVerify returned a counterexample that the Float16 network still
# classifies correctly, so the sample is left undecided. The original note attributes
# these to floating-point rounding error affecting MIPVerify.
wrong_verified_index_by_mipverify =
    map(r -> r.index, filter(r -> !r.verified && r.verified_res == 0, res77))

In [None]:
# Indices of samples recorded as verified robust (verified == true, verified_res == 1).
right_verified_index_by_mipverify = [rec.index for rec in res77 if rec.verified && rec.verified_res == 1]

In [None]:
# Write the indices of samples recorded as verified robust to a text file, one per line.
# The "epsilon0.05" in the file name mirrors `radius = 0.05` in the verification run.
# NOTE(review): rename the file if that setting changes.
let out_path = "./robustness_check_res/MP_77_verified_t_epsilon0.05.txt"
    # open(..., "w") fails if the output directory does not exist yet.
    mkpath(dirname(out_path))
    open(out_path, "w") do file
        for index in right_verified_index_by_mipverify
            println(file, index)
        end
    end
end
