In [1]:
using MLDatasets
using Distances
using Statistics
using Clustering
using Printf
using DataStructures
using SumProductNetworks

In [2]:
train_x, train_y = MNIST.traindata()
train_x = MNIST.convert2features(train_x)
train_2 = train_x[:, train_y .== 2]

784×5958 Array{N0f8,2} with eltype FixedPointNumbers.Normed{UInt8,8}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0

In [3]:
println(size(train_2))
println(typeof(train_2))

(784, 5958)
Array{FixedPointNumbers.Normed{UInt8,8},2}


In [55]:
function learnspn_c(data; minclustersize=100)
    """A minimal implementation of Learn-SPN algorithm
    
    Repeat sum node and product node in turn to form binary SPNs.
    """
    q = Queue{Tuple}()
    root = FiniteSumNode()
    variables = collect(1:size(data)[1])
    instances = collect(1:size(data)[2])
    enqueue!(q, (root, variables, instances))
    
    while length(q) > 0
        node, variables, instances = dequeue!(q)
        println(node)
        println(length(variables), " ", length(instances))
        # stopping conditions
        if length(variables) == 1
            v = variables[1]
            μ_hat = mean(data[v, :])
            σ_hat = std(data[v, :]) + 0.01
            add!(node, UnivariateNode(Normal(μ_hat, σ_hat), v))
        end
        if length(instances) < minclustersize
            for v in variables
                μ_hat = mean(data[v, :])
                σ_hat = std(data[v, :]) + 0.01
                add!(node, UnivariateNode(Normal(μ_hat, σ_hat), v))
            end
        end
        # divide and conquer
        if typeof(node) <: SumNode
            clusterweights = clusterinstances(data, variables, instances)
            for (cluster, weight) in clusterweights
                prod = FiniteProductNode()
                add!(node, prod, log(weight))
                enqueue!(q, (prod, variables, cluster))
            end
        elseif typeof(node) <: ProductNode
            splits = splitvariables(data, variables, instances)
            for split in splits
                sum = FiniteSumNode()
                add!(node, sum)
                enqueue!(q, (sum, split, instances))
            end
        else
            error("encountered unknown node type")
        end
    end
    
    return SumProductNetwork(root)
end

function splitvariables(data, variables, instances)
    """Split variables into two groups
    
    Run G-test (threshold by KL-divergence from observation to expectation).
    """
    function binarize(x)
        binary_x = zeros(Int, size(x))
        binary_x[x .> mean(x)] .= 1
        return binary_x
    end
    slice = data[variables, instances]
    distances = zeros(length(variables))
    p = sum(binarize(slice[1, :]))/length(instances)
    for i in 1:length(variables)
        q = sum(binarize(slice[i, :]))/length(instances)
        e = (p + q)/2
        d = evaluate(KLDivergence(),
                     [p, (1 - p), q, (1 - q)],
                     [e, (1 - e), e, (1 - e)])
        distances[i] = d
    end
    splitone = partialsortperm(distances, 1:floor(Integer, length(variables)/2))
    splittwo = setdiff(variables, splitone)
    return (splitone, splittwo)
end

function clusterinstances(data, variables, instances)
    """Cluster instances into two groups
    
    Run K-means clustering with k=2.
    """
    slice = data[variables, instances]
    results = kmeans(slice, 2)
    clusterone = instances[results.assignments .== 1]
    clustertwo = setdiff(instances, clusterone)
    weight = length(clusterone)/length(instances)
    return ((clusterone, weight), (clustertwo, 1 - weight))
end

clusterinstances (generic function with 1 method)

In [56]:
spn = learnspn_c(train_2)

FiniteSumNode{Float64}(##sum#2529)
	weights = Float64[]
	normalized = false
	No scope set!

784 5958
FiniteProductNode(##prod#2530)
	No scope set!

784 3384
FiniteProductNode(##prod#2531)
	No scope set!

784 2574
FiniteSumNode{Float64}(##sum#2532)
	weights = Float64[]
	normalized = false
	No scope set!

392 3384
FiniteSumNode{Float64}(##sum#2533)
	weights = Float64[]
	normalized = false
	No scope set!

392 3384
FiniteSumNode{Float64}(##sum#2534)
	weights = Float64[]
	normalized = false
	No scope set!

392 2574
FiniteSumNode{Float64}(##sum#2535)
	weights = Float64[]
	normalized = false
	No scope set!

392 2574
FiniteProductNode(##prod#2536)
	No scope set!

392 3284
FiniteProductNode(##prod#2537)
	No scope set!

392 100
FiniteProductNode(##prod#2538)
	No scope set!

392 1892
FiniteProductNode(##prod#2539)
	No scope set!

392 1492
FiniteProductNode(##prod#2540)
	No scope set!

392 2532
FiniteProductNode(##prod#2541)
	No scope set!

392 42
FiniteProductNode(##prod#2542)
	No scope set!

392

MethodError: MethodError: no method matching add!(::FiniteSumNode{Float64}, ::UnivariateNode)
Closest candidates are:
  add!(::SumNode, ::SPNNode, !Matched::T<:Real) where T<:Real at /home/kouariga/Research/julia/packages/SumProductNetworks.jl/src/nodeFunctions.jl:162
  add!(!Matched::ProductNode, ::SPNNode) at /home/kouariga/Research/julia/packages/SumProductNetworks.jl/src/nodeFunctions.jl:170

In [6]:
updatescope!(spn)

UndefVarError: UndefVarError: spn not defined

In [7]:
llhvals = initllhvals(spn, train_2[:, 1])
logpdf!(spn, train_2[:, 1], llhvals)
for node in spn.leaves
    println("logpdf at $(node.id) = $(llhvals[node.id])")
end

UndefVarError: UndefVarError: spn not defined

In [8]:
logp = logpdf(spn, train_2[:, 1])

UndefVarError: UndefVarError: spn not defined

In [9]:
logp = logpdf(spn, train_x[:, 10])

UndefVarError: UndefVarError: spn not defined

In [10]:
spn.leaves

UndefVarError: UndefVarError: spn not defined

In [41]:
FiniteSumNode <: SumNode

true