In [2]:
using ReachabilityAnalysis
using Plots
using ControlSystemsBase
using OffsetArrays
using LinearAlgebra
using Polyhedra
using JLD2

## System Model

In [2]:
A = let
    D = [-1.0 -4.0 0.0 0.0 0.0;
         4.0 -1.0 0.0 0.0 0.0;
         0.0 0.0 -3.0 1.0 0.0;
         0.0 0.0 -1.0 -3.0 0.0;
         0.0 0.0 0.0 0.0 -2.0]
    P = [0.6 -0.1 0.1 0.7 -0.2;
         -0.5 0.7 -0.1 -0.8 0.0;
         0.9 -0.5 0.3 -0.6 0.1;
         0.5 -0.7 0.5 0.6 0.3;
         0.8 0.7 0.6 -0.3 0.2]
    P * D * inv(P)
end
ctrl_delay = 0.1
Φ = ℯ^(A * ctrl_delay)
x0 = Zonotope(fill(10., 5), collect(1.0 * I(5)))

Zonotope{Float64, Vector{Float64}, Matrix{Float64}}([10.0, 10.0, 10.0, 10.0, 10.0], [1.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 0.0; … ; 0.0 0.0 … 1.0 0.0; 0.0 0.0 … 0.0 1.0])

In [3]:
"""
	reach(A, x0, W, H; max_order=Inf, reduced_order=2, remove_redundant=true)

Compute reachable sets for the dynamics ``x[k+1] = A x[k] + w``, where ``w`` is a noise term bounded by `W`.  The initial state is `x0`, and the time horizon is `H`.

If `max_order` is given, we reduce order of the reachable set to `reduced_order` when it exceeds this limit.  If `remove_redundant` is true, redundant generators are removed at each step.
"""
function reach(A::AbstractMatrix, x0::AbstractZonotope, W::AbstractZonotope, H::Integer; max_order::Real=Inf, reduced_order::Real=2, remove_redundant::Bool=true)
	# Preallocate x vector
	x = OffsetArray(fill(x0, H+1), OffsetArrays.Origin(0))

	for k = 1:H
		x[k] = minkowski_sum(linear_map(A, x[k-1]), W)
		if remove_redundant
			x[k] = remove_redundant_generators(x[k])
		end
		if order(x[k]) > max_order
			x[k] = reduce_order(x[k], reduced_order)
		end
	end
	
	F = Flowpipe([ReachSet(x_k, k) for (k, x_k) in enumerate(x)])
end

reach

In [6]:
# Accuracy (top-1/top-5) in percentage, parameter in millions, FLOP in billions
efficient_net_surfaces = (
    B0 = (acc1 = 77.1, acc5 = 93.3, para = 5.3, flop = 0.39),
    B1 = (acc1 = 79.1, acc5 = 94.4, para = 7.8, flop = 0.70),
    B2 = (acc1 = 80.1, acc5 = 94.9, para = 9.2, flop = 1.0),
    B3 = (acc1 = 81.6, acc5 = 95.7, para = 12, flop = 1.8),
    B4 = (acc1 = 82.9, acc5 = 96.4, para = 19, flop = 4.2),
    B5 = (acc1 = 83.6, acc5 = 96.7, para = 30, flop = 9.9),
    B6 = (acc1 = 84.0, acc5 = 96.8, para = 43, flop = 19),
    B7 = (acc1 = 84.3, acc5 = 97.0, para = 66, flop = 37),
)

efficient_net_map_full = vcat(([100/b.acc1 - 1;; 100/b.acc5 - 1;; b.flop] for b in efficient_net_surfaces)...)

8×3 Matrix{Float64}:
 0.297017  0.0718114   0.39
 0.264223  0.059322    0.7
 0.248439  0.0537408   1.0
 0.22549   0.0449321   1.8
 0.206273  0.0373444   4.2
 0.196172  0.0341262   9.9
 0.190476  0.0330579  19.0
 0.18624   0.0309278  37.0

## Exhaustive Search

In [11]:
exhaustive_search(tradeoffmap, Φ, x0, all=true) = let
	upto = size(tradeoffmap, 1)
	points = [(0,0,0,0,0); Inf; Inf]
	for idx in Iterators.product(fill(axes(tradeoffmap, 1), size(Φ, 1))...)
		all || reduce(|, map(x -> x == upto, idx)) || continue
		W = Zonotope(zeros(axes(Φ, 1)), diagm(tradeoffmap[collect(idx), 2]))
		r = reach(Φ, x0, W, 100)
		md = maximum([diameter(x.X) for x in r])
		points = hcat(points, [idx; md; sum(tradeoffmap[collect(idx), 3])])
	end
	points[:,2:end]
end

exhaustive_search (generic function with 2 methods)

In [12]:
@info 2
@time exhaustive_search(efficient_net_map_full[1:2,:], Φ, x0)

┌ Info: 2
└ @ Main /Users/jerry/Projects/date24-asd/src/reachability.ipynb:1


  3.919731 seconds (140.40 M allocations: 12.612 GiB, 15.67% gc time, 5.03% compilation time)


3×32 Matrix{Any}:
   (1, 1, 1, 1, 1)    (2, 1, 1, 1, 1)  …    (2, 2, 2, 2, 2)
 22.6995            22.5662               21.9199
  1.95               2.26                  3.5

In [13]:
@info 5
@time results = [exhaustive_search(efficient_net_map_full[1:5,:], Φ, x0)]
for i in 6:8
    @info i
    @time currentpoints = exhaustive_search(efficient_net_map_full[1:i,:], Φ, x0, false)
    push!(results, currentpoints)
end
jldsave("../data/allsave.jld2"; results)

┌ Info: 5
└ @ Main /Users/jerry/Projects/date24-asd/src/reachability.ipynb:1


410.465575 seconds (13.66 G allocations: 1.201 TiB, 22.66% gc time, 0.12% compilation time)


┌ Info: 6
└ @ Main /Users/jerry/Projects/date24-asd/src/reachability.ipynb:4


681.572011 seconds (20.33 G allocations: 1.787 TiB, 30.30% gc time, 0.03% compilation time)


┌ Info: 7
└ @ Main /Users/jerry/Projects/date24-asd/src/reachability.ipynb:4


1336.319225 seconds (39.48 G allocations: 3.470 TiB, 31.20% gc time)


┌ Info: 8
└ @ Main /Users/jerry/Projects/date24-asd/src/reachability.ipynb:4


2256.696941 seconds (69.77 G allocations: 6.135 TiB, 28.93% gc time)


In [19]:
size(results[4])

(3, 15961)

In [12]:
datt = load("../data/exhaustive_search.jld2")["points"][:,2:end]

4×32768 Matrix{Any}:
      (1, 1, 1, 1, 1)       (2, 1, 1, 1, 1)  …      (8, 8, 8, 8, 8)
    22.6995               22.5662                 20.5743
     1.95                  2.26                  185.0
 false                 false                    true

In [3]:
allsave = load("../data/allsave.jld2")["results"]

4-element Vector{Matrix{Any}}:
 [(1, 1, 1, 1, 1) (2, 1, 1, 1, 1) … (4, 5, 5, 5, 5) (5, 5, 5, 5, 5); 22.69954610804413 22.566218417560833 … 20.90536961128859 20.85146728999061; 1.9500000000000002 2.2600000000000002 … 18.599999999999998 21.0]
 [(6, 1, 1, 1, 1) (6, 2, 1, 1, 1) … (5, 6, 6, 6, 6) (6, 6, 6, 6, 6); 22.297244185156814 22.040120397888074 … 20.73533341045882 20.712471309048453; 11.460000000000003 11.770000000000001 … 43.8 49.5]
 [(7, 1, 1, 1, 1) (7, 2, 1, 1, 1) … (6, 7, 7, 7, 7) (7, 7, 7, 7, 7); 22.285839600383724 22.028715813114985 … 20.673919979892702 20.66633076992045; 20.560000000000002 20.87 … 85.9 95.0]
 [(8, 1, 1, 1, 1) (8, 2, 1, 1, 1) … (7, 8, 8, 8, 8) (8, 8, 8, 8, 8); 22.263100974660716 22.005977187391977 … 20.589466573438585 20.574335097060782; 38.56 38.870000000000005 … 167.0 185.0]

In [29]:
# sort(hcat(allsave...), dims=2)[:,773]
[sum(size.(allsave, 2)[1:x]) for x in 1:4]

4-element Vector{Int64}:
  3125
  7776
 16807
 32768

In [20]:
sort(datt[1:3,:], dims=2)[:,773]

3-element Vector{Any}:
   (1, 2, 5, 1, 5)
 20.72113019142722
  5.6899999999999995