# Inference in Bayesian networks

We can implement exact inference using factors. Recall that factors represent discrete multivariate distributions. We use the following three operations on factors to achieve this:

#### 1. Factor product
We use the factor product to combine two factors to produce a larger factor whose scope is the combined scope of the input factors. If we have $φ(X, Y)$ and $ψ(Y, Z)$, then $φ · ψ$ will be over $X$, $Y$, and $Z$ with $(φ · ψ)(x, y, z) = φ(x, y)ψ(y, z)$. 

#### 2. Factor marginalization
We use factor marginalization to sum out a particular variable from the entire factor table, removing it from the resulting scope.
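
As a formula, summing $Y$ out of a factor $φ(X, Y, Z)$ gives a factor over $X$ and $Z$ (the three-variable scope is only an illustration):

$$
\Big(\sum_{y} φ\Big)(x, z) = \sum_{y} φ(x, y, z)
$$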

#### 3. Factor conditioning
We use factor conditioning with respect to some evidence to remove any rows in the table inconsistent with that evidence.
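
In symbols, conditioning $φ(X, Y, Z)$ on the evidence $Y = y_0$ (where $y_0$ is a placeholder for the observed value) keeps only the matching rows and drops $Y$ from the scope:

$$
φ'(x, z) = φ(x, y_0, z)
$$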

### Let's reimplement the code from the last notebook (`1. Representations`) so we can reuse it

In [8]:
struct Variable
    name::Symbol
    m::Int
end

const Assignment = Dict{Symbol, Int}
const FactorTable = Dict{Assignment, Float64}

struct Factor
    vars::Vector{Variable}
    table::FactorTable
end

Base.Dict{Symbol, T}(a::NamedTuple) where T = 
    Dict{Symbol, T}( k=>v for (k,v) in zip(keys(a), values(a)) )

Base.convert(::Type{Dict{Symbol, T}}, a::NamedTuple) where T = 
    Dict{Symbol, T}(a)

Base.isequal(a::Dict{Symbol, T}, b::NamedTuple) where T = 
    length(a) == length(b) &&
    all(a[k] == v for (k,v) in zip(keys(b), values(b)))

variablenames(ϕ::Factor) = [var.name for var in ϕ.vars]

select(a::Assignment, varnames::Vector{Symbol}) = 
    Assignment( n => a[n] for n in varnames )

select (generic function with 1 method)

In [7]:
import Base.Iterators: product

function assignments(vars::AbstractVector{Variable})
    names = [var.name for var in vars]
    matrix_of_assignments = [Assignment(name => value for (name, value) in zip(names, combo))
        for combo in product((1:var.m for var in vars)...)]
    vec(matrix_of_assignments)
end

assignments (generic function with 1 method)
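
To see what `assignments` enumerates, here is a minimal sketch of the same `product` pattern on plain dictionaries. The names `varnames`, `sizes`, and `combos` are hypothetical and are not part of the notebook's code.

```julia
import Base.Iterators: product

# Hypothetical variables: :x takes 2 values and :y takes 3.
varnames = [:x, :y]
sizes = [2, 3]

# product yields one tuple per combination of values, with the first
# variable varying fastest. The comprehension keeps the product's 2×3
# shape, so vec flattens it into a plain vector.
combos = vec([Dict(zip(varnames, p)) for p in product((1:m for m in sizes)...)])

println(length(combos))   # number of enumerated assignments
println(combos[2])        # the second assignment in enumeration order
```

The number of assignments is the product of the variables' cardinalities, which is why tables over many variables grow quickly.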

In [21]:
using LightGraphs

function normalize!(ϕ::Factor)
    z = sum(p for (a,p) in ϕ.table)
    for (a,p) in ϕ.table
        ϕ.table[a] = p/z
    end
    ϕ
end

struct BayesianNetwork
    vars::Vector{Variable}
    factors::Vector{Factor}
    graph::SimpleDiGraph{Int64}
end

In [10]:
?merge

search: merge merge! mergewith mergewith! MergeSort



```
merge(d::AbstractDict, others::AbstractDict...)
```

Construct a merged collection from the given collections. If necessary, the types of the resulting collection will be promoted to accommodate the types of the merged collections. If the same key is present in another collection, the value for that key will be the value it has in the last collection listed. See also [`mergewith`](@ref) for custom handling of values with the same key.

# Examples

```jldoctest
julia> a = Dict("foo" => 0.0, "bar" => 42.0)
Dict{String, Float64} with 2 entries:
  "bar" => 42.0
  "foo" => 0.0

julia> b = Dict("baz" => 17, "bar" => 4711)
Dict{String, Int64} with 2 entries:
  "bar" => 4711
  "baz" => 17

julia> merge(a, b)
Dict{String, Float64} with 3 entries:
  "bar" => 4711.0
  "baz" => 17.0
  "foo" => 0.0

julia> merge(b, a)
Dict{String, Float64} with 3 entries:
  "bar" => 42.0
  "baz" => 17.0
  "foo" => 0.0
```

---

```
merge(a::NamedTuple, bs::NamedTuple...)
```

Construct a new named tuple by merging two or more existing ones, in a left-associative manner. Merging proceeds left-to-right, between pairs of named tuples, and so the order of fields present in both the leftmost and rightmost named tuples take the same position as they are found in the leftmost named tuple. However, values are taken from matching fields in the rightmost named tuple that contains that field. Fields present in only the rightmost named tuple of a pair are appended at the end. A fallback is implemented for when only a single named tuple is supplied, with signature `merge(a::NamedTuple)`.

!!! compat "Julia 1.1"
    Merging 3 or more `NamedTuple` requires at least Julia 1.1.


# Examples

```jldoctest
julia> merge((a=1, b=2, c=3), (b=4, d=5))
(a = 1, b = 4, c = 3, d = 5)
```

```jldoctest
julia> merge((a=1, b=2), (b=3, c=(d=1,)), (c=(d=2,),))
(a = 1, b = 3, c = (d = 2,))
```

---

```
merge(a::NamedTuple, iterable)
```

Interpret an iterable of key-value pairs as a named tuple, and perform a merge.

```jldoctest
julia> merge((a=1, b=2, c=3), [:b=>4, :d=>5])
(a = 1, b = 4, c = 3, d = 5)
```


### 1. Factor product

In [70]:
function Base.:*(ϕ::Factor, ψ::Factor)
    ψ_names = variablenames(ψ)
    # Variables that appear in ψ but not in ϕ
    ψ_only = setdiff(ψ.vars, ϕ.vars)
    final_table = FactorTable()
    for (ϕ_assignment, ϕ_prob) in ϕ.table
        for ψ_only_assignment in assignments(ψ_only)
            # Extend the ϕ row with values for the ψ-only variables
            complete_assignment = merge(ϕ_assignment, ψ_only_assignment)
            # Project the full row onto ψ's scope to look up its entry
            ψ_assignment = select(complete_assignment, ψ_names)
            final_table[complete_assignment] = ϕ_prob * get(ψ.table, ψ_assignment, 0.0)
        end
    end
    vars = vcat(ϕ.vars, ψ_only)
    Factor(vars, final_table)
end
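
The inner loop of the product relies on two dictionary steps: `merge` to build the full row, and a projection back onto ψ's scope. The sketch below traces a single row using plain dictionaries. The numbers and the names `ϕ_row`, `ψ_only_row`, `ψ_table`, and `product_entry` are illustrative assumptions.

```julia
# One row of ϕ(X, Y) and one value for the ψ-only variable Z (illustrative numbers).
ϕ_row, ϕ_prob = Dict(:x => 1, :y => 2), 0.4
ψ_only_row = Dict(:z => 1)
ψ_table = Dict(Dict(:y => 2, :z => 1) => 0.3)

# merge extends the ϕ row to the combined scope {x, y, z}.
complete = merge(ϕ_row, ψ_only_row)

# Projecting onto ψ's scope recovers the key used to look up ψ's entry.
ψ_row = Dict(n => complete[n] for n in [:y, :z])

# The product entry multiplies the two matching table values.
product_entry = ϕ_prob * get(ψ_table, ψ_row, 0.0)
println(complete, " => ", product_entry)
```

Because the two rows share no keys, `merge` never overwrites a value here; the shared variable `:y` comes from the ϕ row only.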

### 2. Factor marginalization

In [91]:
function marginalize(ϕ::Factor, name::Symbol)
    final_table = FactorTable()
    for (assignment, prob) in ϕ.table
        assignment_without_name = delete!(copy(assignment), name)
        final_table[assignment_without_name] = get(final_table, assignment_without_name, 0.0) + prob
    end
    final_variables = filter(v -> v.name != name, ϕ.vars)
    Factor(final_variables, final_table)
end

marginalize (generic function with 1 method)

### 3. Factor conditioning

We define two methods for conditioning a factor on evidence.

i) The first takes a factor $φ$ and returns a new factor containing only the table entries consistent with the variable named `name` having the value `value`. That variable is then dropped from the scope.

ii) The second takes a factor $φ$ and applies evidence given as a named tuple, conditioning on one variable at a time. The helper `in_scope` returns `true` if a variable named `name` is within the scope of the factor $φ$.

In [92]:
in_scope(name::Symbol, ϕ::Factor) = any(name == v.name for v in ϕ.vars)

function condition(ϕ::Factor, name::Symbol, value)
    # Evidence on a variable outside the scope leaves the factor unchanged
    if !in_scope(name, ϕ)
        return ϕ
    end
    final_table = FactorTable()
    for (assignment, prob) in ϕ.table
        if assignment[name] == value
            final_table[delete!(copy(assignment), name)] = prob
        end
    end
    final_variables = filter(v -> v.name != name, ϕ.vars)
    Factor(final_variables, final_table)
end

function condition(ϕ::Factor, evidence)
    for (name, value) in pairs(evidence)
        ϕ = condition(ϕ, name, value)
    end
    ϕ
end

condition (generic function with 2 methods)
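
The second method works because `pairs` on a named tuple yields `name => value` pairs whose keys are symbols. A small sketch, with a hypothetical `evidence` tuple:

```julia
evidence = (y = 2, z = 1)   # hypothetical evidence: y observed as 2, z as 1

# Each iteration destructures one Symbol => value pair.
for (name, value) in pairs(evidence)
    println(name, " = ", value)
end

# Collect the pairs to inspect them as a vector.
observed = [name => value for (name, value) in pairs(evidence)]
```

The pairs arrive in the order the fields were written, so the factor is conditioned on `:y` first and `:z` second.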

In [16]:
?pairs

search: pairs uppercasefirst Pair partialsort partialsort! partialsortperm



```
pairs(collection)
```

Return an iterator over `key => value` pairs for any collection that maps a set of keys to a set of values. This includes arrays, where the keys are the array indices.

---

```
pairs(IndexLinear(), A)
pairs(IndexCartesian(), A)
pairs(IndexStyle(A), A)
```

An iterator that accesses each element of the array `A`, returning `i => x`, where `i` is the index for the element and `x = A[i]`. Identical to `pairs(A)`, except that the style of index can be selected. Also similar to `enumerate(A)`, except `i` will be a valid index for `A`, while `enumerate` always counts from 1 regardless of the indices of `A`.

Specifying [`IndexLinear()`](@ref) ensures that `i` will be an integer; specifying [`IndexCartesian()`](@ref) ensures that `i` will be a [`CartesianIndex`](@ref); specifying `IndexStyle(A)` chooses whichever has been defined as the native indexing style for array `A`.

Mutation of the bounds of the underlying array will invalidate this iterator.

# Examples

```jldoctest
julia> A = ["a" "d"; "b" "e"; "c" "f"];

julia> for (index, value) in pairs(IndexStyle(A), A)
           println("$index $value")
       end
1 a
2 b
3 c
4 d
5 e
6 f

julia> S = view(A, 1:2, :);

julia> for (index, value) in pairs(IndexStyle(S), S)
           println("$index $value")
       end
CartesianIndex(1, 1) a
CartesianIndex(2, 1) b
CartesianIndex(1, 2) d
CartesianIndex(2, 2) e
```

See also: [`IndexStyle`](@ref), [`axes`](@ref).


### Exact inference

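The following is a naive exact inference algorithm for a discrete Bayesian network `bn`. It takes as input a set of query variable names `query` and `evidence` associating values with observed variables. It multiplies all of the network's factors into a single joint factor, conditions on the evidence, sums out every remaining variable that is not in the query, and normalizes the result.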

The algorithm computes a joint distribution over the query variables in the form of a factor. 
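
Written out, with query variables $Q$, evidence $E = e$, and the remaining hidden variables $H$ (these symbols are introduced here for illustration), the computation is:

$$
P(Q \mid E = e) = \frac{\sum_{h} P(Q, h, e)}{\sum_{q} \sum_{h} P(q, h, e)},
\qquad
P(q, h, e) = \prod_{i} φ_i
$$

Here each $φ_i$ is one of the network's factors, evaluated at the entries of $(q, h, e)$ that lie in its scope.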

> We introduce the `ExactInference` type to allow for `infer` to be called with different inference methods, as shall be seen in the rest of this chapter.

Notice that the `ExactInference` argument is not used in the body of the function. It serves only to select the method through multiple dispatch.

In [19]:
struct ExactInference end

In [23]:
function infer(::ExactInference, bn, query, evidence)
    ϕ = prod(bn.factors)
    ϕ = condition(ϕ, evidence)
    for name in setdiff(variablenames(ϕ), query)
        ϕ = marginalize(ϕ, name)
    end
    normalize!(ϕ)
end

infer (generic function with 1 method)
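
The dispatch pattern used by `infer` can be seen in isolation in the sketch below. The types `MethodA` and `MethodB` and the function `describe` are hypothetical and stand in for different inference methods.

```julia
# Hypothetical marker types; neither carries any data.
struct MethodA end
struct MethodB end

# The unnamed first argument only selects which method runs.
describe(::MethodA, x) = "method A received $x"
describe(::MethodB, x) = "method B received $x"

println(describe(MethodA(), 1))
println(describe(MethodB(), 1))
```

Adding another inference method then amounts to defining one more empty struct and one more method, without touching existing callers.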

### Let's test on the example given in the book

In [27]:
X = Variable(:x, 2)
Y = Variable(:y, 2)
Z = Variable(:z, 2)

ϕ1 = Factor([X,Y], FactorTable(
        (x=1, y=1) => 0.3,
        (x=1, y=2) => 0.4,
        (x=2, y=1) => 0.2,
        (x=2, y=2) => 0.1,
        ))

ϕ2 = Factor([Y,Z], FactorTable(
        (y=1, z=1) => 0.2,
        (y=1, z=2) => 0.0,
        (y=2, z=1) => 0.3,
        (y=2, z=2) => 0.5,
        ))

Factor(Variable[Variable(:y, 2), Variable(:z, 2)], Dict(Dict(:y => 2, :z => 2) => 0.5, Dict(:y => 1, :z => 1) => 0.2, Dict(:y => 1, :z => 2) => 0.0, Dict(:y => 2, :z => 1) => 0.3))

#### Test factor product

In [75]:
result = ϕ1 * ϕ2
println(result.vars)
result.table

Variable[Variable(:x, 2), Variable(:y, 2), Variable(:z, 2)]


Dict{Dict{Symbol, Int64}, Float64} with 8 entries:
  Dict(:y=>1, :z=>1, :x=>1) => 0.06
  Dict(:y=>1, :z=>1, :x=>2) => 0.04
  Dict(:y=>1, :z=>2, :x=>1) => 0.0
  Dict(:y=>2, :z=>2, :x=>1) => 0.2
  Dict(:y=>2, :z=>1, :x=>2) => 0.03
  Dict(:y=>1, :z=>2, :x=>2) => 0.0
  Dict(:y=>2, :z=>2, :x=>2) => 0.05
  Dict(:y=>2, :z=>1, :x=>1) => 0.12

#### Test factor marginalization

In [95]:
X = Variable(:x, 2)
Y = Variable(:y, 2)
Z = Variable(:z, 2)

ϕ = Factor([X,Y,Z], FactorTable(
        (x=1, y=1, z=1) => 0.08,
        (x=1, y=1, z=2) => 0.31,
        (x=1, y=2, z=1) => 0.09,
        (x=1, y=2, z=2) => 0.37,
        (x=2, y=1, z=1) => 0.01,
        (x=2, y=1, z=2) => 0.05,
        (x=2, y=2, z=1) => 0.02,
        (x=2, y=2, z=2) => 0.07,
        ))

Factor(Variable[Variable(:x, 2), Variable(:y, 2), Variable(:z, 2)], Dict(Dict(:y => 1, :z => 1, :x => 1) => 0.08, Dict(:y => 1, :z => 1, :x => 2) => 0.01, Dict(:y => 1, :z => 2, :x => 1) => 0.31, Dict(:y => 2, :z => 2, :x => 1) => 0.37, Dict(:y => 2, :z => 1, :x => 2) => 0.02, Dict(:y => 1, :z => 2, :x => 2) => 0.05, Dict(:y => 2, :z => 2, :x => 2) => 0.07, Dict(:y => 2, :z => 1, :x => 1) => 0.09))

In [96]:
result = marginalize(ϕ, Y.name)
println(result.vars)
result.table

Variable[Variable(:x, 2), Variable(:z, 2)]


Dict{Dict{Symbol, Int64}, Float64} with 4 entries:
  Dict(:z=>2, :x=>2) => 0.12
  Dict(:z=>2, :x=>1) => 0.68
  Dict(:z=>1, :x=>1) => 0.17
  Dict(:z=>1, :x=>2) => 0.03
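
Each entry can be checked by hand by adding the $y = 1$ and $y = 2$ rows of the original table:

$$
\begin{aligned}
(x = 1, z = 1):&\quad 0.08 + 0.09 = 0.17 \\
(x = 1, z = 2):&\quad 0.31 + 0.37 = 0.68 \\
(x = 2, z = 1):&\quad 0.01 + 0.02 = 0.03 \\
(x = 2, z = 2):&\quad 0.05 + 0.07 = 0.12
\end{aligned}
$$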

#### Test factor conditioning

In [98]:
result = condition(ϕ, :y, 2)
println(result.vars)
result.table

Variable[Variable(:x, 2), Variable(:z, 2)]


Dict{Dict{Symbol, Int64}, Float64} with 4 entries:
  Dict(:z=>2, :x=>2) => 0.07
  Dict(:z=>2, :x=>1) => 0.37
  Dict(:z=>1, :x=>1) => 0.09
  Dict(:z=>1, :x=>2) => 0.02
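
The conditioned table is not yet a probability distribution, since its entries do not sum to 1. The sketch below normalizes it in the same way as `normalize!`, but on a plain dictionary keyed by `(x, z)` tuples. The tuple keys are an assumption made for brevity; the notebook keys its tables by `Assignment` dictionaries.

```julia
# Conditioned table from above, keyed by (x, z) tuples for brevity.
table = Dict((1, 1) => 0.09, (1, 2) => 0.37, (2, 1) => 0.02, (2, 2) => 0.07)

z = sum(values(table))                               # normalizing constant
normalized = Dict(k => p / z for (k, p) in table)    # entries now sum to 1

println(z)
println(sum(values(normalized)))
```

After dividing by the normalizing constant, the table can be read as the conditional distribution over `x` and `z` given `y = 2`.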