Skip to content

Commit

Permalink
Merge pull request #132 from SebastianGuadalupe/master
Browse files Browse the repository at this point in the history
Add Ai2 solver using boxes
  • Loading branch information
tomerarnon committed Jul 28, 2020
2 parents c2fe59d + c237daf commit 86c5a17
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/NeuralVerification.jl
Expand Up @@ -75,7 +75,8 @@ include("reachability/exactReach.jl")
include("reachability/maxSens.jl")
include("reachability/ai2.jl")
include("reachability/ai2z.jl")
export ExactReach, MaxSens, Ai2, Ai2z
include("reachability/box.jl")
export ExactReach, MaxSens, Ai2, Ai2z, Box

include("satisfiability/bab.jl")
include("satisfiability/sherlock.jl")
Expand Down
37 changes: 37 additions & 0 deletions src/reachability/box.jl
@@ -0,0 +1,37 @@
"""
Box <: Solver
Box performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network.
# Problem requirement
1. Network: any depth, ReLU activation (more activations to be supported in the future)
2. Input: Hyperrectangle
3. Output: Hyperrectangle
# Return
`ReachabilityResult`
# Method
Reachability analysis using using boxes.
# Property
Sound but not complete.
"""
struct Box <: Solver end

function solve(solver::Box, problem::Problem)
reach = forward_network(solver, problem.network, problem.input)
return check_inclusion(reach, problem.output)
end

forward_layer(solver::Box, layer::Layer, inputs::Vector{<:LazySet}) = forward_layer.(solver, layer, inputs)

function forward_layer(solver::Box, layer::Layer, input::AbstractPolytope)
return forward_layer(solver, layer, overapproximate(input, Hyperrectangle))
end

function forward_layer(solver::Box, layer::Layer, input::Hyperrectangle)
outlinear = overapproximate(AffineMap(layer.weights, input, layer.bias), Hyperrectangle)
relued_subsets = forward_partition(layer.activation, outlinear)
return relued_subsets
end
7 changes: 6 additions & 1 deletion src/reachability/utils/reachability.jl
Expand Up @@ -29,7 +29,7 @@ function check_inclusion(reach::P, output) where P<:LazySet
end

# return a vector so that append! is consistent with the relu forward_partition
forward_partition(act::Id, input::HPolytope) = [input]
forward_partition(act::Id, input::AbstractPolytope) = [input]

forward_partition(act::Id, input::Zonotope) = input

Expand Down Expand Up @@ -62,3 +62,8 @@ end
function forward_partition(act::ReLU, input::Zonotope)
return overapproximate(Rectification(input), Zonotope)
end

# for Hyperrectangles
function forward_partition(act::ReLU, input::Hyperrectangle)
return rectify(input)
end
2 changes: 1 addition & 1 deletion test/identity_network.jl
Expand Up @@ -18,7 +18,7 @@
problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset))
problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping))

for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z()]
for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()]
holds = solve(solver, problem_holds)
violated = solve(solver, problem_violated)

Expand Down
2 changes: 1 addition & 1 deletion test/inactive_relus.jl
Expand Up @@ -18,7 +18,7 @@
problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset))
problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping))

for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z()]
for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()]
holds = solve(solver, problem_holds)
violated = solve(solver, problem_violated)

Expand Down
2 changes: 1 addition & 1 deletion test/relu_network.jl
Expand Up @@ -18,7 +18,7 @@
problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset))
problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping))

for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z()]
for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()]
holds = solve(solver, problem_holds)
violated = solve(solver, problem_violated)

Expand Down

0 comments on commit 86c5a17

Please sign in to comment.