
Sum-Product Networks in Julia

This package implements sum-product networks (SPNs), a class of tractable probabilistic models, in Julia. It provides a clean, modular interface for SPNs and a set of helper and utility functions for working with SPN models efficiently.

News

  • 18.10.2018 - The package is officially registered.
  • 10.10.2018 - The package now provides more efficient logpdf routines and allows for multithreaded computations.
  • 24.09.2018 - SumProductNetworks now works under Julia 1.0.

Installation

Make sure you have Julia 1.0 running. The package can be installed using Julia's package mode. (You can enter the package mode by typing ] in the REPL.)

pkg> add SumProductNetworks

Usage

The following minimal example illustrates the use of the package.

using SumProductNetworks

# Create a root sum node.
root = FiniteSumNode();

# Add two product nodes to the root.
add!(root, FiniteProductNode(), log(0.3)); # Use a weight of 0.3
add!(root, FiniteProductNode(), log(0.7)); # Use a weight of 0.7

# Add Normal distributions to the product nodes, i.e. leaves.
for prod in children(root)
    for d in 1:2 # Assume 2-D data
        add!(prod, UnivariateNode(Normal(), d));
    end
end

# Compile the constructed network to an SPN type
spn = SumProductNetwork(root);

# Print statistics on the network.
println(spn)

# Update the scope of all nodes, i.e. propagate the scope bottom-up.
updatescope!(spn)

# Evaluate the network on some data.
x = [0.8, 1.2];
logp = logpdf(spn, x)

# Save the network to a DOT file.
export_network(spn, "mySPN.dot")
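
To make the value returned by `logpdf` in the example above more concrete, the following sketch recomputes it by hand without loading the package. It assumes the usual SPN semantics: a sum node is a weighted mixture of its children, and a product node multiplies the densities of its children. The helpers `std_normal_logpdf` and `logsumexp` are illustrative names and are not part of SumProductNetworks.

```julia
# Hand-rolled evaluation of the example network (hypothetical helpers,
# not the package API): a mixture of two products of standard normals.

# Log density of a standard normal at x.
std_normal_logpdf(x) = -0.5 * x^2 - 0.5 * log(2π)

# Numerically stable log(sum(exp.(v))).
function logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

x = [0.8, 1.2]
logw = [log(0.3), log(0.7)]   # log weights of the root sum node

# A product node multiplies its leaves, i.e. it sums their log densities.
prod_logp = [sum(std_normal_logpdf.(x)) for _ in 1:2]

# The sum node mixes its children in log space.
logp = logsumexp(logw .+ prod_logp)
println(logp)
```

Because both product nodes hold identical leaves and the weights sum to one, the result is the log density of a 2-D standard normal at `x`, approximately -2.878.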

Advanced Usage

Besides the basic functionality for nodes and SPNs, this package provides helper functions for more advanced use cases. The following example illustrates a more advanced task.

using SumProductNetworks
using AxisArrays

N = 100
D = 2

x = rand(N, D)

# Create a root sum node.
root = FiniteSumNode{Float32}();

# Add two product nodes to the root.
add!(root, FiniteProductNode(), Float32(log(0.3))); # Use a weight of 0.3
add!(root, FiniteProductNode(), Float32(log(0.7))); # Use a weight of 0.7

# Add Normal distributions to the product nodes, i.e. leaves.
for prod in children(root)
    for d in 1:2 # Assume 2-D data
        add!(prod, UnivariateNode(Normal(), d));
    end
end

# Compile the constructed network to an SPN type
spn = SumProductNetwork(root);

# Allocate storage for the logpdf value of every node in the SPN.
idx = Axis{:id}(collect(keys(spn)))
llhvals = AxisArray(Matrix{Float32}(undef, N, length(spn)), 1:N, idx)

# Compute logpdf values for all nodes in the network.
logpdf(spn, x; idx, llhvals)

# Print the logpdf value for each leaf.
for node in spn.leaves
    println("logpdf at $(node.id) = $(llhvals[:,node.id])")
end

# Assign observations to their most likely child under a sum node.
function assignobs!(node::SumNode, observations::Vector{Int})
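    # Index of the most likely child for every observation. argmax with dims = 2
    # returns CartesianIndex values, so keep only the column (child) index.
    childllh = llhvals[observations, map(c -> c.id, children(node))]
    j = vec(map(ci -> ci[2], argmax(childllh, dims = 2)))

    # Set observations to the node.
    setobs!(node, observations)

    # Set observations for the children of the node.
    for k in 1:length(node)
        setobs!(node[k], observations[findall(j .== k)])
    end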

    # Get the parametric type of the sum node, i.e. Float32.
    T = eltype(node)

    # Update the weights of the node.
    w = map(c -> nobs(c) / nobs(node), children(node))
    for k in 1:length(node)
        logweights(node)[k] = T(log(w[k]))
    end
end

# Call assignment function for the root.
assignobs!(spn.root, collect(1:N))
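
The hard-assignment step in `assignobs!` can also be tried in isolation. The sketch below uses a plain matrix of made-up log-likelihoods instead of the package's data structures, so every name in it is hypothetical. It shows how the column index is extracted from the `CartesianIndex` values returned by `argmax` and how the assignment proportions become log weights.

```julia
# Toy illustration of hard assignment with a plain matrix
# (all names are hypothetical and independent of the package).

# Log-likelihoods of 4 observations (rows) under 2 children (columns).
child_llh = Float32[-1.0 -2.0;
                    -3.0 -0.5;
                    -0.2 -4.0;
                    -2.5 -2.0]

# argmax with dims = 2 returns CartesianIndex values; keep the column part.
assignment = vec(map(ci -> ci[2], argmax(child_llh, dims = 2)))

# Count the observations per child and turn the proportions into log weights.
K = size(child_llh, 2)
counts = [count(a -> a == k, assignment) for k in 1:K]
logw = Float32.(log.(counts ./ sum(counts)))

println(assignment)
println(counts)
println(logw)
```

Note that a child that receives no observations would get a log weight of `-Inf` under this scheme.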

Documentation

Datatypes

The following types are implemented and supported in this package. The abstract type hierarchy is designed so that existing types are easy to extend and efficient implementations based on type dispatch are possible.

# Abstract type hierarchy.
SPNNode
Node <: SPNNode
Leaf <: SPNNode
SumNode{T} <: Node
ProductNode <: Node

# Node types.
FiniteSumNode() <: SumNode
FiniteProductNode() <: ProductNode
IndicatorNode(value::Int, scope::Int) <: Leaf
UnivariateNode(dist::UnivariateDistribution, scope::Int) <: Leaf
MultivariateNode(dist::MultivariateDistribution, scope::Vector{Int}) <: Leaf

For more details on the individual node types, use Julia's built-in documentation system (type ? in the REPL to enter help mode).

In addition to these types, the package also provides a composite type to represent an SPN:

SumProductNetwork(root::Node)
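
To illustrate what a dispatch-based design buys, here is a minimal, self-contained sketch of a similar hierarchy. The `Toy*` types and `toylogpdf` are hypothetical and deliberately do not reuse the package's names; the point is only that evaluation can be written as one small method per node type, so a new node type needs just one new method.

```julia
# Toy type hierarchy (hypothetical names, not the package's types) showing
# how evaluation can be written as one method per node type.
abstract type ToyNode end

struct ToyLeaf <: ToyNode
    mean::Float64   # mean of a unit-variance normal
    scope::Int      # index of the variable this leaf models
end

struct ToyProduct <: ToyNode
    children::Vector{ToyNode}
end

struct ToySum <: ToyNode
    children::Vector{ToyNode}
    logweights::Vector{Float64}
end

# Leaf: log density of a unit-variance normal on the leaf's variable.
toylogpdf(n::ToyLeaf, x) = -0.5 * (x[n.scope] - n.mean)^2 - 0.5 * log(2π)

# Product node: the log densities of the children add up.
toylogpdf(n::ToyProduct, x) = sum(toylogpdf(c, x) for c in n.children)

# Sum node: log-sum-exp of the weighted children.
function toylogpdf(n::ToySum, x)
    v = n.logweights .+ [toylogpdf(c, x) for c in n.children]
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

p1 = ToyProduct([ToyLeaf(0.0, 1), ToyLeaf(0.0, 2)])
p2 = ToyProduct([ToyLeaf(1.0, 1), ToyLeaf(1.0, 2)])
toyroot = ToySum([p1, p2], [log(0.3), log(0.7)])
println(toylogpdf(toyroot, [0.8, 1.2]))
```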

Structure Learning

Utility functions for structure learning are currently not implemented in this package. An additional package providing a variety of structure learning algorithms will be provided soon.

The interface for learning SPN structure is:

generate_spn(X::Matrix, algo::Symbol; params...)

Utility Functions on a SumProductNetwork

The following utility functions can be used on an instance of a SumProductNetwork.

# Get all nodes of the network in topological order.
values(spn::SumProductNetwork)

# Get the ids of all nodes in the network.
keys(spn::SumProductNetwork)

# Number of nodes in the network.
length(spn::SumProductNetwork)

# Indexing using an id.
spn[id::Symbol]

# Locally normalize an SPN.
normalize!(spn::SumProductNetwork)

# Number of free parameters in the SPN.
complexity(spn::SumProductNetwork)

# Export the SPN to a DOT file.
export_network(spn::SumProductNetwork, filename::String)
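
Local normalization is commonly understood as rescaling the weights of each sum node so that they sum to one. The sketch below shows that operation on a plain vector of log weights. `normalize_logweights` is a hypothetical helper, and the package's `normalize!` may handle further details that are not covered here.

```julia
# Normalize a vector of log weights so that the weights sum to one
# (hypothetical helper, independent of the package).
function normalize_logweights(logw::AbstractVector{<:Real})
    m = maximum(logw)
    logz = m + log(sum(exp.(logw .- m)))   # log of the sum of the weights
    return logw .- logz
end

logw = log.([2.0, 6.0])                    # unnormalized weights 2 and 6
normalized = normalize_logweights(logw)
println(exp.(normalized))                  # weights 0.25 and 0.75, up to rounding
```

Working in log space and subtracting the maximum before exponentiating avoids overflow and underflow when the log weights are large in magnitude.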

Utility Functions on Nodes

The following utility functions can be used on an instance of an SPN Node.

# Indexing an internal node returns the respective child.
node[i::Int]

# Add a child to an internal node (with or without weight).
add!(node::Node, child::SPNNode)
add!(node::Node, child::SPNNode, logw::Real)

# Remove a child from an internal node.
remove!(node::Node, childIndex::Int)

# The depth of the SPN rooted at the node.
depth(node::SPNNode)

# Get all children of a node.
children(node::Node)

# Get the number of children of a node.
length(node::Node)

# Get all parents of a node.
parents(node::SPNNode)

# Check whether the node has a weights field.
hasweights(node::Node)

# Get all weights of the node.
weights(node::Node) = exp.(logweights(node))

# Get all log weights of the node.
logweights(node::Node)

# Is the SPN rooted at the node normalized?
isnormalized(node::SPNNode)

General Utility Functions

The following functions are general utility functions.

# Independence test by Margaritis and Thrun for discrete sets.
bmitest(X::Vector{Int}, Y::Vector{Int})

Contribute

Feel free to open a PR if you want to contribute!

References

Please consider citing any of the following publications if you use this package.

  • Martin Trapp, Tamas Madl, Robert Peharz, Franz Pernkopf, Robert Trappl: Safe Semi-Supervised Learning of Sum-Product Networks. UAI 2017.
  • Martin Trapp, Robert Peharz, Marcin Skowron, Tamas Madl, Franz Pernkopf, Robert Trappl: Structure Inference in Sum-Product Networks using Infinite Sum-Product Trees. NIPS 2016 - Workshop on Practical Bayesian Nonparametrics.