diff --git a/src/NeuralVerification.jl b/src/NeuralVerification.jl index 5c003b09..ecb4c181 100644 --- a/src/NeuralVerification.jl +++ b/src/NeuralVerification.jl @@ -75,7 +75,8 @@ include("reachability/exactReach.jl") include("reachability/maxSens.jl") include("reachability/ai2.jl") include("reachability/ai2z.jl") -export ExactReach, MaxSens, Ai2, Ai2z +include("reachability/box.jl") +export ExactReach, MaxSens, Ai2, Ai2z, Box include("satisfiability/bab.jl") include("satisfiability/sherlock.jl") diff --git a/src/reachability/box.jl b/src/reachability/box.jl new file mode 100644 index 00000000..76be4d3c --- /dev/null +++ b/src/reachability/box.jl @@ -0,0 +1,37 @@ +""" + Box <: Solver + +Box performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network. + +# Problem requirement +1. Network: any depth, ReLU activation (more activations to be supported in the future) +2. Input: Hyperrectangle +3. Output: Hyperrectangle + +# Return +`ReachabilityResult` + +# Method +Reachability analysis using using boxes. + +# Property +Sound but not complete. +""" +struct Box <: Solver end + +function solve(solver::Box, problem::Problem) + reach = forward_network(solver, problem.network, problem.input) + return check_inclusion(reach, problem.output) +end + +forward_layer(solver::Box, layer::Layer, inputs::Vector{<:LazySet}) = forward_layer.(solver, layer, inputs) + +function forward_layer(solver::Box, layer::Layer, input::AbstractPolytope) + return forward_layer(solver, layer, overapproximate(input, Hyperrectangle)) +end + +function forward_layer(solver::Box, layer::Layer, input::Hyperrectangle) + outlinear = overapproximate(AffineMap(layer.weights, input, layer.bias), Hyperrectangle) + relued_subsets = forward_partition(layer.activation, outlinear) + return relued_subsets +end diff --git a/src/reachability/utils/reachability.jl b/src/reachability/utils/reachability.jl index cd4eccbb..3b157e57 100644 --- a/src/reachability/utils/reachability.jl +++ b/src/reachability/utils/reachability.jl @@ -29,7 +29,7 @@ function check_inclusion(reach::P, output) where P<:LazySet end # return a vector so that append! is consistent with the relu forward_partition -forward_partition(act::Id, input::HPolytope) = [input] +forward_partition(act::Id, input::AbstractPolytope) = [input] forward_partition(act::Id, input::Zonotope) = input @@ -62,3 +62,8 @@ end function forward_partition(act::ReLU, input::Zonotope) return overapproximate(Rectification(input), Zonotope) end + +# for Hyperrectangles +function forward_partition(act::ReLU, input::Hyperrectangle) + return rectify(input) +end diff --git a/test/identity_network.jl b/test/identity_network.jl index 929bce79..13afdea8 100644 --- a/test/identity_network.jl +++ b/test/identity_network.jl @@ -18,7 +18,7 @@ problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset)) problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping)) - for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z()] + for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()] holds = solve(solver, problem_holds) violated = solve(solver, problem_violated) diff --git a/test/inactive_relus.jl b/test/inactive_relus.jl index b9011f6d..6f0bc40e 100644 --- a/test/inactive_relus.jl +++ b/test/inactive_relus.jl @@ -18,7 +18,7 @@ problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset)) problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping)) - for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z()] + for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()] holds = solve(solver, problem_holds) violated = solve(solver, problem_violated) diff --git a/test/relu_network.jl b/test/relu_network.jl index 43ac53b7..ddd9d7d7 100644 --- a/test/relu_network.jl +++ b/test/relu_network.jl @@ -18,7 +18,7 @@ problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset)) problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping)) - for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z()] + for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()] holds = solve(solver, problem_holds) violated = solve(solver, problem_violated)