Skip to content

Commit

Permalink
Added capability for nested temporal operators. Added new grammar and…
Browse files Browse the repository at this point in the history
… analysis for state-dependent expresssions
  • Loading branch information
ancorso committed Oct 20, 2020
1 parent 60ce9b1 commit fa23247
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 72 deletions.
36 changes: 36 additions & 0 deletions examples/policy_example.jl
@@ -0,0 +1,36 @@
using InterpretableValidation
using Distributions
using ExprRules
using Random
using POMDPModels
using POMDPs
using POMDPSimulators
using POMDPPolicies

mdp = SimpleGridWorld(size = (9,9), rewards = Dict(GWPos(5,5) => 1, GWPos(6,5) => 0), tprob = .7, discount=1)

mvts = MvTimeseriesDistribution(:x => IID(collect(1:1.), Categorical(4)))
s_fn(s) = SymbolTable(:sx => s[1], :sy => s[2])
x_fn(v) = actions(mdp)[v[:x][1]]

# Demonstrate that a certain choice of policy obtains high reward
expr = Meta.parse("[(sx .== 5) .& (sy .<= 5), x .== 1] && [(sx .<= 4) .& (sy .<= 5), x .== 4] && [(sx .>= 6) .& (sy .<= 5), x .== 3]")
p = get_policy(expr, s_fn, x_fn, mvts)
mean([simulate(RolloutSimulator(), mdp, RandomPolicy(mdp)) for i=1:1000])
mean([simulate(RolloutSimulator(), mdp, FunctionPolicy(p)) for i=1:1000])

# Setup the grammar
x_samp_dist = default_comparison_distribution(mvts)
x_comp = default_comparisons(mvts)
s_samp_dist = Dict{Symbol, Distribution}(:sx => Categorical(9), :sy => Categorical(9))
all_syms = [Symbol(".<="), Symbol(".=="), Symbol(".>=")]
s_comp = Dict(:sx => all_syms, :sy => all_syms)
g = create_policy_grammar(1, x_samp_dist, x_comp, s_samp_dist, s_comp)

# Setup the loss function
lf = policy_loss_fn(mdp, s_fn, x_fn, mvts)

# optimize
optimize(()->nothing, mvts, Npop = 1000, Niter = 30, loss = lf, grammar = g, max_depth = 5)


9 changes: 5 additions & 4 deletions src/InterpretableValidation.jl
Expand Up @@ -20,7 +20,8 @@ module InterpretableValidation
include("timeseries_distributions.jl")

# Inverse Logic
export sample_constraints, flex_not,
export sample_constraints, flex_not, not_inv,
parse_implications, parse_implications!, and_expressions,
all_before, all_before_inv, all_after, all_after_inv, all_between,
all_between_inv, any_between, any_between_inv,
and_inv, or_inv, all_inv, any_inv, bitwise_and_inv, bitwise_or_inv
Expand All @@ -33,9 +34,9 @@ module InterpretableValidation
include("constrained_distributions.jl")

# Optimization and Grammar
export loss_fn, sample_comparison, create_grammar,
default_comparison_distribution, optimize, set_global_grammar_params,
discrete_action_mdp, continuous_action_mdp, sample_history
export loss_fn, sample_comparison, create_stl_grammar, create_policy_grammar,
default_comparison_distribution, default_comparisons, optimize_stl_policy, optimize_timed_stl, set_global_grammar_params,
discrete_action_mdp, continuous_action_mdp, sample_history, get_policy, policy_loss_fn
include("optimization.jl")
end

2 changes: 1 addition & 1 deletion src/constrained_distributions.jl
Expand Up @@ -151,7 +151,7 @@ end
Base.showerror(io::IO, e::InfeasibleConstraint) = print(io, "Infeasible Constraint: ", e.msg)

# Tries to sample a time series that satisfies the expression from the provided MvTimeseriesDistribution
function Base.rand(rng::AbstractRNG, expr::Expr, d::MvTimeseriesDistribution; validity_trials = 10)
function Base.rand(rng::AbstractRNG, expr::Expr, d::MvTimeseriesDistribution; validity_trials = 2)
N = N_pts(d)
for i=1:validity_trials
d2 = MvTimeseriesDistribution()
Expand Down
95 changes: 69 additions & 26 deletions src/inverse_logic.jl
@@ -1,36 +1,39 @@
const IV_TERMINALS = [Symbol(".=="), Symbol(".<="), Symbol(".>=")] # Calls that should terminate the tree search
const IV_EXPANDERS = [:any, :all] # Calls the expand from scalar to time series
const IV_PARAMETERIZED = Dict(:all_before => 1, :all_after => 1, :all_between => 2, :any_between=>2)
const IV_PARAMETERIZED = Dict(:any => 0, :all => 0, :all_before => 1, :all_after => 1, :all_between => 2, :any_between=>2)

# Negation operator applied to bool or :anybool
flex_not(b::Union{Symbol, Bool}) = (b == :anybool ? :anybool : !b)

all_before(τ, i) = all(τ[1:i])
all_after(τ, i) = all(τ[i:end])

all_between(τ, i, j) = all(τ[min(i,j):max(i,j)])
any_between(τ, i, j) = any(τ[min(i,j):max(i,j)])
all_between(τ, i, j) = all(τ[min(i,j):min(length(τ), max(i,j))])
any_between(τ, i, j) = any(τ[min(i,j):min(length(τ), max(i,j))])

function all_between_inv(out, i, j, N, rng::AbstractRNG)
l = min(i,j)
l > N && throw(InfeasibleConstraint("Couldn't find feasible constraint due to out of bounds range"))
h = min(max(i,j), N)
arr = Array{Any}(undef, N)
fill!(arr, :anybool)
if out == true
arr[min(i,j):max(i,j)] .= true
arr[l:h] .= true
elseif out == false
arr[rand(rng, min(i,j):max(i,j))] = false
arr[rand(rng, l:h)] = false
end
(arr,)
end

function any_between_inv(out, i, j, N, rng::AbstractRNG)
l = min(i,j)
l > N && throw(InfeasibleConstraint("Couldn't find feasible constraint due to out of bounds range"))
h = min(max(i,j), N)
arr = Array{Any}(undef, N)
fill!(arr, :anybool)
if out == true
pt = rand(rng, min(i,j):max(i,j))
arr[1:pt-1] .= false
arr[pt] = true
arr[rand(rng, l:h)] = true
elseif out == false
arr[min(i,j):max(i,j)] .= false
arr[l:h] .= false
end
(arr,)
end
Expand Down Expand Up @@ -58,6 +61,11 @@ function all_after_inv(out, i, N, rng::AbstractRNG)
(arr,)
end

function not_inv(out, rng::AbstractRNG)
out == :anybool && return :anybool
(length(out) > 1 ) ? ([flex_not(v) for v in out],) : flex_not(out)
end


# Inverse of the "and" operator which takes two boolean inputs
# If output is true then both inputs are true
Expand Down Expand Up @@ -126,17 +134,19 @@ function bitwise_op_inv(out, op_inv, rng::AbstractRNG)
end

# Defines the invers of the bitwise "and" operator
bitwise_and_inv(out, rng::AbstractRNG) = bitwise_op_inv(out, and_inv, rng)
bitwise_and_inv(out, rng::AbstractRNG) = (out isa Array) ? bitwise_op_inv(out, and_inv, rng) : and_inv(out, rng)


# Defines the inverse of the bitwise "or" operator
bitwise_or_inv(out, rng::AbstractRNG) = bitwise_op_inv(out, or_inv, rng)
bitwise_or_inv(out, rng::AbstractRNG) = (out isa Array) ? bitwise_op_inv(out, or_inv, rng) : or_inv(out, rng)

# Mapping an operation to its inverse
bool_inverses = Dict(
:&& => and_inv,
:|| => or_inv,
:&& => bitwise_and_inv,
:|| => bitwise_or_inv,
:any => any_inv,
:all => all_inv,
:! => not_inv,
:all_before => all_before_inv,
:all_after => all_after_inv,
:all_between => all_between_inv,
Expand All @@ -158,6 +168,7 @@ end
# `constraints` - The list of constraints and their corresponding truth values
# `N` - The length of the time series
function sample_constraints!(expr, truthval, constraints, N, rng::AbstractRNG)
all(truthval .== :anybool) && return
if expr.head == :call && expr.args[1] in IV_TERMINALS
# Add constrain expression to the list of constraints and return
push!(constraints, [expr, truthval])
Expand All @@ -169,27 +180,59 @@ function sample_constraints!(expr, truthval, constraints, N, rng::AbstractRNG)
# Special handing for expanding operators that need to be passed `N`
op = expr.args[1]
inv = bool_inverses[op]
if op in IV_EXPANDERS
inv = inv(truthval, N, rng)
sample_constraints!(expr.args[2], inv[1], constraints, N, rng)
elseif op in keys(IV_PARAMETERIZED)
if op in keys(IV_PARAMETERIZED)
ps = IV_PARAMETERIZED[op]
inv = inv(truthval, expr.args[3:3+(ps - 1)]..., N, rng)
sample_constraints!(expr.args[2], inv[1], constraints, N, rng)
if length(truthval) == 1
inv_res = inv(truthval, expr.args[3:3+(ps - 1)]..., N, rng)
sample_constraints!(expr.args[2], inv_res[1], constraints, N, rng)
else
for i=1:N
inv_res = [fill(:anybool, i-1)..., inv(truthval[i], expr.args[3:3+(ps - 1)]..., N-i+1, rng)[1]...]
sample_constraints!(expr.args[2], inv_res, constraints, N, rng)
end
end
else
inv = inv(truthval, rng)
for i in 1:length(inv)
sample_constraints!(expr.args[i+1], inv[i], constraints, N, rng)
inv_res = inv(truthval, rng)
for i in 1:length(inv_res)
sample_constraints!(expr.args[i+1], inv_res[i], constraints, N, rng)
end
end

else
# Here "head" contains the operator and args contains the expressions
# Get the inverse and recurse directly
inv = bool_inverses[expr.head](truthval, rng)
for i in 1:length(inv)
sample_constraints!(expr.args[i], inv[i], constraints, N, rng)
inv_res = bool_inverses[expr.head](truthval, rng)
for i in 1:length(inv_res)
sample_constraints!(expr.args[i], inv_res[i], constraints, N, rng)
end
end
end


## The stuff below here is for implies clauses
function parse_implications(expr)
implications = Dict{Expr, Expr}()
parse_implications!(expr, implications)
implications
end

function parse_implications!(expr, implications)
if expr.head == :vect
implications[expr.args[1]] = expr.args[2]
return
elseif expr.head == :&&
parse_implications!(expr.args[1], implications)
parse_implications!(expr.args[2], implications)
else
throw(error("Unregonized head ", expr.head))
end
end

function and_expressions(exprs)
new_expr = exprs[1]
for i=2:length(exprs)
new_expr = Expr(:.&, new_expr, exprs[i])
end
Expr(:call, :all, new_expr)
end

26 changes: 17 additions & 9 deletions src/mvrandn.jl
Expand Up @@ -3,6 +3,8 @@
lnPhi(x) = -0.5 * x^2 - 0.69314718055994530941723212 +
log(erfcx(x / 1.4142135623730950488016887242))

sqrt2pi = sqrt(2*pi)

# Computes ln(P(a<Z<b)) where Z~N(0,1) very accurately for any 'a', 'b'
function lnNpr(a, b)
pa, pb = lnPhi(abs(a)), lnPhi(abs(b))
Expand Down Expand Up @@ -112,14 +114,17 @@ function cholperm(Σ_in, l, u)
perm, L, z = [1:d...], zeros(d,d), zeros(d,1)
for j = 1:d
pr = Inf*ones(d) # compute marginal prob.
rd = j:d # search remaining dimensions
rd = collect(j:d) # search remaining dimensions
ud = collect(1:j-1)
D = diag(Σ)
s = D[rd] .- sum(L[rd,1:j-1].^2, dims=2)
s = D[rd] .- sum(L[rd,ud].^2, dims=2)
s[s.<0] .= 1e-16
s = sqrt.(s)

tl = (l[rd] .- L[rd,1:j-1]*z[1:j-1])./s
tu = (u[rd] .- L[rd,1:j-1]*z[1:j-1])./s
L_z = L[rd,ud]*z[ud]

tl = (l[rd] .- L_z)./s
tu = (u[rd] .- L_z)./s

pr[rd] = lnNpr.(tl,tu)

Expand All @@ -142,21 +147,24 @@ function cholperm(Σ_in, l, u)
perm[jk] .= perm[kj] # keep track of permutation

# construct L sequentially via Cholesky computation
s = Σ[j,j] - sum(L[j,1:j-1].^2)
s = Σ[j,j] - sum(L[j,ud].^2)

if s<-0.01 error("Σ is not positive semi-definite") end
if s < 0 s = 1e-16 end
L[j,j] = sqrt(s)
Ld = L[j,j]

L[j+1:d,j:j] .= (Σ[j+1:d,j]-L[j+1:d,1:j-1]*L[j:j,1:j-1]')./L[j,j]
L[j+1:d,j:j] .= (Σ[j+1:d,j]-L[j+1:d,1:j-1]*L[j:j,1:j-1]')./Ld

# find mean value, z(j), of truncated normal:
tl = (l[j].-L[j:j,1:j-1]*z[1:j-1])./L[j,j]
tu = (u[j].-L[j:j,1:j-1]*z[1:j-1])./L[j,j]
L_z = L[j:j,ud]*z[ud]

tl = (l[j].-L_z)./Ld
tu = (u[j].-L_z)./Ld
w = lnNpr.(tl,tu) # aids in computing expected value of trunc. normal

@assert length(tl) == 1 && length(tu) == 1 && length(w) == 1
z[j] = (exp(-.5*tl[1]^2 - w[1]) - exp(-5*tu[1]^2. - w[1]))/sqrt(2*pi)
z[j] = (exp(-.5*tl[1]^2 - w[1]) - exp(-5*tu[1]^2. - w[1]))/sqrt2pi
end
L, l, u, perm
end
Expand Down

0 comments on commit fa23247

Please sign in to comment.