Skip to content

Commit

Permalink
Merge f494748 into 50e27ea
Browse files Browse the repository at this point in the history
  • Loading branch information
castrong committed Jul 24, 2020
2 parents 50e27ea + f494748 commit 4534e89
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/optimization/convDual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ function get_bounds(nnet::Network, input::Vector{Float64}, ϵ::Float64)
u = Vector{Vector{Float64}}()
γ = Vector{Vector{Float64}}()
μ = Vector{Vector{Vector{Float64}}}()
input_ReLU = Vector{Vector{Float64}}()

v1 = layers[1].weights'
push!(γ, layers[1].bias)
Expand All @@ -94,19 +95,23 @@ function get_bounds(nnet::Network, input::Vector{Float64}, ϵ::Float64)
n_input = length(layers[i-1].bias)
n_output = length(layers[i].bias)

input_ReLU = relaxed_ReLU.(last(l), last(u))
D = Diagonal(input_ReLU) # a matrix whose diagonal values are the relaxed_ReLU values (maybe should be sparse?)
last_input_ReLU = relaxed_ReLU.(last(l), last(u))
push!(input_ReLU, last_input_ReLU)
D = Diagonal(last_input_ReLU) # a matrix whose diagonal values are the relaxed_ReLU values (maybe should be sparse?)

# Propagate existing terms
WD = layers[i].weights*D
v1 = v1 * WD' # TODO CHECK
map!(g -> WD*g, γ, γ)

# Updating ν_j for all previous layers
for M in μ
map!(m -> WD*m, M, M)
end

# New terms
push!(γ, layers[i].bias)
push!(μ, new_μ(n_input, n_output, input_ReLU, WD))
push!(μ, new_μ(n_input, n_output, last_input_ReLU, WD))

# Compute bounds
ψ = v1' * input + sum(γ)
Expand All @@ -127,7 +132,7 @@ function all_neg_pos_sums(slopes, l, μ, n_output)
# Need to debug
for (i, ℓ) in enumerate(l) # ℓ::Vector{Float64}
for (j, M) in enumerate(μ[i]) # M::Vector{Float64}
if 0 < slopes[j] < 1 # if in the triangle region of relaxed ReLU
if 0 < slopes[i][j] < 1 # if in the triangle region of relaxed ReLU
#posind = M .> 0
neg .+= ℓ[j] * min.(M, 0) #-M .* !posind # multiply by boolean to set the undesired values to 0.0
pos .+= ℓ[j] * max.(M, 0) #M .* posind
Expand Down

0 comments on commit 4534e89

Please sign in to comment.