Skip to content

Commit

Permalink
rename interval -> domain
Browse files Browse the repository at this point in the history
  • Loading branch information
tomerarnon committed Aug 19, 2020
1 parent a702b1f commit 706c42d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
24 changes: 12 additions & 12 deletions src/adversarial/neurify.jl
Expand Up @@ -52,10 +52,10 @@ function solve(solver::Neurify, problem::Problem)

reach, max_violation_con, splits = select!(reach_list, solver.tree_search)

intervals = constraint_refinement(solver, problem.network, reach, max_violation_con, splits)
subdomains = constraint_refinement(solver, problem.network, reach, max_violation_con, splits)

for interval in intervals
reach = forward_network(solver, problem.network, interval)
for domain in subdomains
reach = forward_network(solver, problem.network, domain)
result, max_violation_con = check_inclusion(solver, last(reach).sym, problem.output, problem.network)
if result.status == :violated
return result
Expand All @@ -76,8 +76,8 @@ function check_inclusion(solver::Neurify, reach::SymbolicInterval,
model = Model(solver)
set_silent(model)

x = @variable(model, [1:dim(reach.interval)])
add_set_constraint!(model, reach.interval, x)
x = @variable(model, [1:dim(reach.domain)])
add_set_constraint!(model, reach.domain, x)

max_violation = 0.0
max_violation_con = nothing
Expand All @@ -104,7 +104,7 @@ function check_inclusion(solver::Neurify, reach::SymbolicInterval,
# NOTE This entire else branch should be eliminated for the paper version
else
# NOTE Is this even valid if the problem isn't solved optimally?
if value(x) reach.interval
if value(x) reach.domain
error("Not OPTIMAL, but x in the input set.\n
This is usually caused by open input set.\n
Please check your input constraints.")
Expand Down Expand Up @@ -137,7 +137,7 @@ function constraint_refinement(solver::Neurify,
# custom intersection function that doesn't do constraint pruning
= (set, lc) -> HPolytope([constraints_list(set); lc])

subsets = [reach[1].sym.interval] # all the reaches have the same domain
subsets = [reach[1].sym.domain] # all the reaches have the same domain

# If either of the normal vectors is the 0-vector, we must skip it.
# It cannot be used to create a halfspace constraint.
Expand Down Expand Up @@ -222,14 +222,14 @@ function forward_linear(solver::Neurify, input::SymbolicIntervalGradient, layer:
output_Low, output_Up = interval_map(layer.weights, input.sym.Low, input.sym.Up)
output_Up[:, end] += layer.bias
output_Low[:, end] += layer.bias
sym = SymbolicInterval(output_Low, output_Up, input.sym.interval)
sym = SymbolicInterval(output_Low, output_Up, input.sym.domain)
return SymbolicIntervalGradient(sym, input.LΛ, input.UΛ)
end

# Symbolic forward_act
function forward_act(solver::Neurify, input::SymbolicIntervalGradient, layer::Layer{ReLU})

interval = input.sym.interval
domain = input.sym.domain
Low, Up = input.sym.Low, input.sym.Up
n_node = n_nodes(layer)

Expand All @@ -243,8 +243,8 @@ function forward_act(solver::Neurify, input::SymbolicIntervalGradient, layer::La
# These are direct views into the rows of the parent arrays.
lowᵢⱼ, upᵢⱼ, out_lowᵢⱼ, out_upᵢⱼ = @views Low[j, :], Up[j, :], output_Low[j, :], output_Up[j, :]

up_low, up_up = bounds(upᵢⱼ, interval)
low_low, low_up = bounds(lowᵢⱼ, interval)
up_low, up_up = bounds(upᵢⱼ, domain)
low_low, low_up = bounds(lowᵢⱼ, domain)

up_slope = act_gradient(up_low, up_up)
low_slope = act_gradient(low_low, low_up)
Expand All @@ -256,7 +256,7 @@ function forward_act(solver::Neurify, input::SymbolicIntervalGradient, layer::La

LΛᵢ[j], UΛᵢ[j] = low_slope, up_slope
end
sym = SymbolicInterval(output_Low, output_Up, interval)
sym = SymbolicInterval(output_Low, output_Up, domain)
= push!(input.LΛ, LΛᵢ)
= push!(input.UΛ, UΛᵢ)
return SymbolicIntervalGradient(sym, LΛ, UΛ)
Expand Down
20 changes: 10 additions & 10 deletions src/adversarial/reluVal.jl
Expand Up @@ -35,7 +35,7 @@ end
struct SymbolicInterval{F<:AbstractPolytope}
Low::Matrix{Float64}
Up::Matrix{Float64}
interval::F
domain::F
end
# Data to be passed during forward_layer
struct SymbolicIntervalGradient{F<:AbstractPolytope, N<:Real}
Expand Down Expand Up @@ -94,8 +94,8 @@ end

function bisect_interval_by_max_smear(nnet::Network, reach::SymbolicIntervalMask)
LG, UG = get_gradient_bounds(nnet, reach.LΛ, reach.UΛ)
feature, monotone = get_max_smear_index(nnet, reach.sym.interval, LG, UG) #monotonicity not used in this implementation.
return collect(split_interval(reach.sym.interval, feature))
feature, monotone = get_max_smear_index(nnet, reach.sym.domain, LG, UG) #monotonicity not used in this implementation.
return collect(split_interval(reach.sym.domain, feature))
end

function select!(reach_list, tree_search)
Expand All @@ -110,8 +110,8 @@ function select!(reach_list, tree_search)
end

function symbol_to_concrete(reach::SymbolicInterval{<:Hyperrectangle})
lower = [lower_bound(l, reach.interval) for l in eachrow(reach.Low)]
upper = [upper_bound(u, reach.interval) for u in eachrow(reach.Up)]
lower = [lower_bound(l, reach.domain) for l in eachrow(reach.Low)]
upper = [upper_bound(u, reach.domain) for u in eachrow(reach.Up)]
return Hyperrectangle(low = lower, high = upper)
end

Expand All @@ -121,7 +121,7 @@ function check_inclusion(reach::SymbolicInterval{<:Hyperrectangle}, output, nnet
issubset(reachable, output) && return CounterExampleResult(:holds)

# Sample the middle point
middle_point = center(reach.interval)
middle_point = center(reach.domain)
y = compute_output(nnet, middle_point)
y output || return CounterExampleResult(:violated, middle_point)

Expand All @@ -138,14 +138,14 @@ function forward_linear(solver::ReluVal, input::SymbolicIntervalMask, layer::Lay
output_Low, output_Up = interval_map(W, input.sym.Low, input.sym.Up)
output_Up[:, end] += b
output_Low[:, end] += b
sym = SymbolicInterval(output_Low, output_Up, input.sym.interval)
sym = SymbolicInterval(output_Low, output_Up, input.sym.domain)
return SymbolicIntervalGradient(sym, input.LΛ, input.UΛ)
end

# Symbolic forward_act
function forward_act(::ReluVal, input::SymbolicIntervalMask, layer::Layer{ReLU})

interval = input.sym.interval
interval = input.sym.domain
Low, Up = input.sym.Low, input.sym.Up

n_node = n_nodes(layer)
Expand Down Expand Up @@ -217,6 +217,6 @@ bounds(v, domain) = (lower_bound(v, domain), upper_bound(v, domain))
# a node in the network. Equivalent to the upper-upper
# bound minus the lower-lower bound
function radius(sym::SymbolicInterval, j::Integer)
upper_bound(@view(sym.Up[j, :]), sym.interval) -
lower_bound(@view(sym.Low[j, :]), sym.interval)
upper_bound(@view(sym.Up[j, :]), sym.domain) -
lower_bound(@view(sym.Low[j, :]), sym.domain)
end

0 comments on commit 706c42d

Please sign in to comment.