Skip to content

Commit

Permalink
Merge 1a3c300 into 86c5a17
Browse files Browse the repository at this point in the history
  • Loading branch information
tomerarnon committed Jul 28, 2020
2 parents 86c5a17 + 1a3c300 commit 0fc66e6
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 130 deletions.
4 changes: 1 addition & 3 deletions src/NeuralVerification.jl
Expand Up @@ -74,9 +74,7 @@ include("reachability/utils/reachability.jl")
include("reachability/exactReach.jl")
include("reachability/maxSens.jl")
include("reachability/ai2.jl")
include("reachability/ai2z.jl")
include("reachability/box.jl")
export ExactReach, MaxSens, Ai2, Ai2z, Box
export ExactReach, MaxSens, Ai2, Ai2h, Ai2z, Box

include("satisfiability/bab.jl")
include("satisfiability/sherlock.jl")
Expand Down
78 changes: 60 additions & 18 deletions src/reachability/ai2.jl
@@ -1,11 +1,25 @@
"""
Ai2
Ai2{T}
Ai2 performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network.
`Ai2` performs over-approximated reachability analysis to compute the over-approximated
output reachable set for a network. `T` can be `Hyperrectangle`, `Zonotope`, or
`HPolytope`, and determines the amount of over-approximation (and hence also performance
tradeoff). The original implementation (from [1]) uses Zonotopes, so we consider this
the "benchmark" case. The `HPolytope` case is more precise, but slower, and the opposite
is true of the `Hyperrectangle` case.
Note that initializing `Ai2()` defaults to `Ai2{Zonotope}`.
The following aliases also exist for convenience:
```julia
const Ai2h = Ai2{HPolytope}
const Ai2z = Ai2{Zonotope}
const Box = Ai2{Hyperrectangle}
```
# Problem requirement
1. Network: any depth, ReLU activation (more activations to be supported in the future)
2. Input: HPolytope
2. Input: AbstractPolytope
3. Output: AbstractPolytope
# Return
Expand All @@ -18,30 +32,58 @@ Reachability analysis using split and join.
Sound but not complete.
# Reference
T. Gehr, M. Mirman, D. Drashsler-Cohen, P. Tsankov, S. Chaudhuri, and M. Vechev,
[1] T. Gehr, M. Mirman, D. Drashsler-Cohen, P. Tsankov, S. Chaudhuri, and M. Vechev,
"Ai2: Safety and Robustness Certification of Neural Networks with Abstract Interpretation,"
in *2018 IEEE Symposium on Security and Privacy (SP)*, 2018.
## Note
Efficient over-approximation of intersections and unions involving zonotopes relies on Theorem 3.1 of
[2] Singh, G., Gehr, T., Mirman, M., Püschel, M., & Vechev, M. (2018). Fast
and effective robustness certification. In Advances in Neural Information
Processing Systems (pp. 10802-10813).
"""
struct Ai2 <: Solver end
struct Ai2{T<:Union{Hyperrectangle, Zonotope, HPolytope}} <: Solver end

Ai2() = Ai2{Zonotope}()
const Ai2h = Ai2{HPolytope}
const Ai2z = Ai2{Zonotope}
const Box = Ai2{Hyperrectangle}

function solve(solver::Ai2, problem::Problem)
reach = forward_network(solver, problem.network, problem.input)
return check_inclusion(reach, problem.output)
end

forward_layer(solver::Ai2, layer::Layer, inputs::Vector{<:AbstractPolytope}) = forward_layer.(solver, layer, inputs)
forward_layer(solver::Ai2, L::Layer, inputs::Vector) = forward_layer.(solver, L, inputs)

function forward_layer(solver::Ai2h, L::Layer{ReLU}, input::AbstractPolytope)
= affine_map(L, input)
relued_subsets = forward_partition(L.activation, Ẑ) # defined in reachability.jl
return convex_hull(UnionSetArray(relued_subsets))
end

# method for Zonotope and Hyperrectangle, if the input set isn't a Zonotope
function forward_layer(solver::Union{Ai2z, Box}, L::Layer{ReLU}, input::AbstractPolytope)
return forward_layer(solver, L, overapproximate(input, Hyperrectangle))
end

function forward_layer(solver::Ai2z, L::Layer{ReLU}, input::AbstractZonotope)
= affine_map(L, input)
return overapproximate(Rectification(Ẑ), Zonotope)
end


function forward_layer(solver::Ai2, layer::Layer, input::AbstractPolytope)
outlinear = affine_map(layer, input)
relued_subsets = forward_partition(layer.activation, outlinear) # defined in ExactReach
return convex_hull(relued_subsets)
function forward_layer(solver::Box, L::Layer{ReLU}, input::AbstractZonotope)
= approximate_affine_map(L, input)
return rectify(Ẑ)
end

# extend lazysets convex_hull to a vector of polytopes
function LazySets.convex_hull(sets::Vector{<:AbstractPolytope}; backend = CDDLib.Library())
hull = first(sets)
for P in sets
hull = convex_hull(hull, P, backend = backend)
end
return hull
end
function forward_layer(solver::Ai2, L::Layer{Id}, input)
return affine_map(L, input)
end


function convex_hull(U::UnionSetArray{<:Any, <:HPolytope})
tohrep(VPolytope(LazySets.convex_hull(U)))
end
48 changes: 0 additions & 48 deletions src/reachability/ai2z.jl

This file was deleted.

37 changes: 0 additions & 37 deletions src/reachability/box.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/reachability/exactReach.jl
Expand Up @@ -31,7 +31,7 @@ end

forward_layer(solver::ExactReach, layer::Layer, input) = forward_layer(solver, layer, convert(HPolytope, input))

function forward_layer(solver::ExactReach, layer::Layer, input::Vector{HPolytope})
function forward_layer(solver::ExactReach, layer::Layer, input::Vector{<:HPolytope})
output = Vector{HPolytope}(undef, 0)
for i in 1:length(input)
input[i] = affine_map(layer, input[i])
Expand Down
25 changes: 5 additions & 20 deletions src/reachability/utils/reachability.jl
Expand Up @@ -18,24 +18,19 @@ function check_inclusion(reach::Vector{<:LazySet}, output)
for poly in reach
issubset(poly, output) || return ReachabilityResult(:violated, reach)
end
return ReachabilityResult(:holds, similar(reach, 0))
return ReachabilityResult(:holds, reach)
end

function check_inclusion(reach::P, output) where P<:LazySet
if issubset(reach, output)
return ReachabilityResult(:holds, P[])
end
return ReachabilityResult(:violated, [reach])
return ReachabilityResult(issubset(reach, output) ? :holds : :violated, [reach])
end

# return a vector so that append! is consistent with the relu forward_partition
forward_partition(act::Id, input::AbstractPolytope) = [input]

forward_partition(act::Id, input::Zonotope) = input
forward_partition(act::Id, input) = [input]

function forward_partition(act::ReLU, input::HPolytope)
n = dim(input)
output = Vector{HPolytope}(undef, 0)
output = Vector{HPolytope{Float64}}(undef, 0)
C, d = tosimplehrep(input)
dh = [d; zeros(n)]
for h in 0:(2^n)-1
Expand All @@ -56,14 +51,4 @@ function getP(h::Int64, n::Int64)
vec[i] = ifelse(str[i] == '1', 1, 0)
end
return Diagonal(vec)
end

# forward_partition for Zonotopes
function forward_partition(act::ReLU, input::Zonotope)
return overapproximate(Rectification(input), Zonotope)
end

# for Hyperrectangles
function forward_partition(act::ReLU, input::Hyperrectangle)
return rectify(input)
end
end
2 changes: 1 addition & 1 deletion test/identity_network.jl
Expand Up @@ -18,7 +18,7 @@
problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset))
problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping))

for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()]
for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2h(), Box()]
holds = solve(solver, problem_holds)
violated = solve(solver, problem_violated)

Expand Down
2 changes: 1 addition & 1 deletion test/inactive_relus.jl
Expand Up @@ -18,7 +18,7 @@
problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset))
problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping))

for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()]
for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2h(), Box()]
holds = solve(solver, problem_holds)
violated = solve(solver, problem_violated)

Expand Down
2 changes: 1 addition & 1 deletion test/relu_network.jl
Expand Up @@ -18,7 +18,7 @@
problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset))
problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping))

for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()]
for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2h(), Box()]
holds = solve(solver, problem_holds)
violated = solve(solver, problem_violated)

Expand Down

0 comments on commit 0fc66e6

Please sign in to comment.