In [None]:
using StatsBase
using Distributions

In [8]:
using TensorFlow
using MLDataUtils

In [9]:
using DensityEstimationML
using Plots

In [10]:
"""
    NeuralDensityEstimator

Handle to a TensorFlow session whose graph models a one-dimensional CDF with a
small neural network.

# Fields
- `sess`: the session that owns the graph and its variables.
- `optimizer`: the training op produced by `train.minimize`.
- `t`: the input placeholder at which `cdf` and `pdf` are evaluated.
- `pdf`: node holding the gradient of `cdf` with respect to `t`.
- `cdf`: node holding the network output for input `t`.
"""
struct NeuralDensityEstimator
    sess::Session

    # Network nodes
    optimizer::Tensor
    t::Tensor
    pdf::Tensor
    cdf::Tensor
end

In [112]:
# Six-element Int vector filled with 2. `fill` allocates and initializes in one
# step instead of allocating an uninitialized vector and then overwriting it.
a = fill(2, 6)
a

6-element Array{Int64,1}:
 2
 2
 2
 2
 2
 2

In [194]:
"""
    leaky_relu6(z)

ReLU6 plus a small linear leak: `0.001 * z + nn.relu6(z)`.

NOTE(review): nothing in the visible notebook calls this; the constructor applies
`nn.relu6` directly.
"""
function leaky_relu6(z)
    leak = 0.001 * z
    return leak + nn.relu6(z)
end

"""
    NeuralDensityEstimator(prob_layer_sizes)

Build a fresh TensorFlow graph and session for a neural CDF model, initialize its
variables, and return a `NeuralDensityEstimator` wrapping them.

`prob_layer_sizes` lists the hidden-layer widths (e.g. `[32, 64]`). The network maps
its input through layers of sizes `[1; prob_layer_sizes; 1]`:
- hidden layers use `nn.relu6`;
- the output layer uses `nn.sigmoid`, so the output ("cdf") lies in (0, 1).

Nothing constrains the weights. The output is therefore not guaranteed to be
non-decreasing in `t`, and the derived "pdf" can be negative.

Training objective: given start points `t1`, end points `t2` and target masses
`countprob_between`, minimise, with Adam, `reduce_mean(...; axis=2)` of
`0.5 * (cdf(t2) - cdf(t1) - target)^2`.

`pdf` is `gradients(cdf1, t1)`, the derivative of the network output with respect to
its input.

NOTE(review): `fit!` fetches the nodes `"t1"`, `"t2"`, `"countprob_between"` and
`"loss"` from the graph by name. The printed tensor name `t1:1` suggests `@tf` names
nodes after the local variables below. Renaming those locals would then break
`fit!`; confirm against the TensorFlow.jl `@tf` documentation.
"""
function NeuralDensityEstimator(prob_layer_sizes)
    sess = Session(Graph())
    @tf begin
        # Start/end points of the training pairs; shape [1, -1] = one row, any count.
        t1 = placeholder(Float32, shape=[1, -1])
        t2 = placeholder(Float32, shape=[1, -1])
        # Target probability mass between each (t1, t2) pair.
        countprob_between = placeholder(Float32, shape=[-1])
        
        # Per-layer activations for the t1 and t2 inputs; the input itself is layer 1.
        zp1 = [t1]
        zp2 = [t2]
        layer_sizes= [1; prob_layer_sizes; 1]
        # relu6 on hidden layers, sigmoid on the output so it lies in (0, 1).
        act_funs = Vector{Function}(length(layer_sizes)-1)
        act_funs[:] = nn.relu6
        act_funs[end] = nn.sigmoid
        for ii in 2:length(layer_sizes)
            below_size = layer_sizes[ii-1]
            above_size = layer_sizes[ii]
            
            act_fun = act_funs[ii-1]
            
            # One weight/bias pair per layer, shared by the t1 and t2 towers.
            Wii = get_variable("W_$ii", [above_size, below_size], Float32)
            bii = get_variable("b_$ii", [above_size, 1], Float32)
            
            push!(zp1, act_fun(Wii*zp1[end] .+ bii))
            push!(zp2, act_fun(Wii*zp2[end] .+ bii))
        end
        
        cdf1 = zp1[end]
        cdf2 = zp2[end]
        #Assumes t2>t1
        # (get_cdf_training_pairs only emits pairs from sorted data, so t2 >= t1.)
        cdf_between = cdf2 - cdf1
                
        
        # NOTE(review): the depwarn trace in the notebook output points at
        # In[194]:35, which appears to be this line (dotted `.*`/`.^` on a Tensor)
        # - confirm.
        losses = 0.5.*(cdf_between .- countprob_between).^2
        
        loss=reduce_mean(losses; axis=2)
        optimizer = train.minimize(train.AdamOptimizer(), loss)
        
        # d(cdf)/dt, evaluated at the t1 placeholder.
        pdf1 = gradients(cdf1, t1)
    end
    
    # Initialize every variable created by get_variable above.
    run(sess, global_variables_initializer())
    
    # The t1 tower doubles as the single-point evaluation path used by cdf/pdf.
    NeuralDensityEstimator(sess, optimizer, t1, pdf1, cdf1)
end

NeuralDensityEstimator

In [195]:
# Run `node` with the estimator's input placeholder fed a 1×1 matrix holding `t`,
# and return the single resulting value. `t` is converted to Float32 explicitly
# because the input placeholder is declared with that dtype.
function _evaluate_at(est::NeuralDensityEstimator, node, t::Real)
    ts = fill(Float32(t), 1, 1)
    return run(est.sess, node, Dict(est.t => ts))[1]
end

"""
    cdf(est::NeuralDensityEstimator, t::Real)

Evaluate the learned CDF of `est` at the single point `t`.
"""
function Distributions.cdf(est::NeuralDensityEstimator, t::Real)
    return _evaluate_at(est, est.cdf, t)
end

"""
    pdf(est::NeuralDensityEstimator, t::Real)

Evaluate the learned density of `est` (the gradient of its CDF node with respect to
the input) at the single point `t`.
"""
function Distributions.pdf(est::NeuralDensityEstimator, t::Real)
    return _evaluate_at(est, est.pdf, t)
end



In [196]:
@doc raw"""
    get_cdf_training_pairs(observations)

Build training triples for fitting a CDF to the one-dimensional sample `observations`.

The observations are sorted into a copy; the input is not modified. For every pair of
sorted positions `i < j` the function records:
- the start point ``t_1 = x_{(i)}``;
- the end point ``t_2 = x_{(j)}``;
- the empirical probability mass between them, ``(j - i) / n``.

When there are no ties, that mass is the fraction of observations ``t_k`` with
``t_1 \le t_k < t_2``. With ties it is counted by sorted position, not by value.

Returns three `Vector{Float32}`s `(t1s, t2s, probs)` of length ``n(n-1)/2``, where
``n`` is `length(observations)`. All three are empty when ``n < 2``.
"""
function get_cdf_training_pairs(observations)
    sorted_obs = sort(observations)
    n = length(sorted_obs)
    npairs = n * (n - 1) ÷ 2

    t1s = Float32[]
    t2s = Float32[]
    probs = Float32[]
    # The exact output size is known up front; avoid repeated regrowth.
    sizehint!(t1s, npairs)
    sizehint!(t2s, npairs)
    sizehint!(probs, npairs)

    for (ii, t1) in enumerate(sorted_obs)
        for jj in ii+1:n
            push!(t1s, t1)
            push!(t2s, sorted_obs[jj])
            # Fraction of the sample at sorted positions ii..jj-1.
            push!(probs, (jj - ii) / n)
        end
    end

    return t1s, t2s, probs
end



get_cdf_training_pairs

In [197]:
"""
    fit!(estimator::NeuralDensityEstimator, observations; epochs=20, batch_size=1024)

Train `estimator` in place on the one-dimensional sample `observations` and return it.

The `(t1, t2, fraction)` triples from `get_cdf_training_pairs` are built once. In each
of `epochs` epochs they are shuffled and fed to the graph's optimizer in mini-batches
of `batch_size`, and the mean batch loss is printed.

Graph nodes are looked up by name (`"loss"`, `"t1"`, `"t2"`, `"countprob_between"`),
so those names must exist in the estimator's graph.
"""
function StatsBase.fit!(estimator::NeuralDensityEstimator, observations;
    epochs = 20,
    batch_size = 1024)
    gr = estimator.sess.graph
    loss_node = gr["loss"]
    t1_node = gr["t1"]
    t2_node = gr["t2"]
    prob_node = gr["countprob_between"]

    # The pair set depends only on `observations`, and building it is O(n^2),
    # so compute it once rather than once per epoch.
    training_pairs = get_cdf_training_pairs(observations)

    for ii in 1:epochs
        batch_losses = Float32[]
        # Reshuffle every epoch. MLDataUtils reports some data points as unused by the
        # batching, so reshuffling varies which ones are left out.
        for (t1_o, t2_o, probs_o) in eachbatch(shuffleobs(training_pairs), batch_size)
            # Placeholders are [1, -1] rows, hence the transposes.
            loss_o, _ = run(estimator.sess,
                [loss_node, estimator.optimizer],
                Dict(t1_node => t1_o', t2_node => t2_o', prob_node => probs_o))
            push!(batch_losses, loss_o[1])
        end
        epoch_loss = mean(batch_losses)
        println("Epoch $ii: loss: $(epoch_loss)")
    end
    estimator
end

In [205]:
# Estimator with hidden layers of 32 and 64 units, plus the sample it is fitted to.
est = NeuralDensityEstimator([32; 64])
data = GenerateDatasets.magdon_ismail_and_atiya();


2017-08-07 17:56:49.634704: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX TITAN X, pci bus id: 0000:01:00.0)
Stacktrace:
 [1] [1mdepwarn[22m[22m[1m([22m[22m::String, ::Symbol[1m)[22m[22m at [1m./deprecated.jl:70[22m[22m
 [2] [1m(::Base.##716#717)[22m[22m[1m([22m[22m::Float64, ::TensorFlow.Tensor{Float32}[1m)[22m[22m at [1m./deprecated.jl:346[22m[22m
 [3] [1m(::TensorFlow.###8#9#11{Base.##716#717})[22m[22m[1m([22m[22m::Array{Any,1}, ::Function, ::Float64, ::Vararg{Any,N} where N[1m)[22m[22m at [1m/home/uniwa/students2/students/20361362/linux/.julia/v0.6/TensorFlow/src/meta.jl:67[22m[22m
 [4] [1mNeuralDensityEstimator[22m[22m[1m([22m[22m::Array{Int64,1}[1m)[22m[22m at [1m./In[194]:35[22m[22m
 [5] [1minclude_string[22m[22m[1m([22m[22m::String, ::String[1m)[22m[22m at [1m./loading.jl:515[22m[22m
 [6] [1minclude_string[22m[22m[1m([22m[22m::Module,

In [206]:
# Train for 100 epochs; the remaining `fit!` keywords keep their defaults.
fit!(est, data, epochs = 100)

Epoch 1: loss: 0.08006751


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 2: loss: 0.046977956
Epoch 3: loss: 0.0085209375


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 4: loss: 0.005304319
Epoch 5: loss: 0.003396785


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 6: loss: 0.0023797657
Epoch 7: loss: 0.0021810306


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 8: loss: 0.0021199351
Epoch 9: loss: 0.0020639251


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 10: loss: 0.0020287037
Epoch 11: loss: 0.0019947905


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 12: loss: 0.0019686772
Epoch 13: loss: 0.001954392
Epoch 14: loss: 0.0019356569
Epoch 15: loss: 0.0019233845


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 16: loss: 0.001920972
Epoch 17: loss: 0.0019144812


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 18: loss: 0.0019133193
Epoch 19: loss: 0.0019064002


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 20: loss: 0.0019004102
Epoch 21: loss: 0.0018975632


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 22: loss: 0.0019001849
Epoch 23: loss: 0.0018915924


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 24: loss: 0.0018957583
Epoch 25: loss: 0.0018953802


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 26: loss: 0.0018967061
Epoch 27: loss: 0.0018996648


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 28: loss: 0.0018994547
Epoch 29: loss: 0.0018970398


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 30: loss: 0.0018952544
Epoch 31: loss: 0.0018969594


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 32: loss: 0.0018964917
Epoch 33: loss: 0.0018968823


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 34: loss: 0.0018922492
Epoch 35: loss: 0.0018949332


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 36: loss: 0.0018930074
Epoch 37: loss: 0.0018852577


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 38: loss: 0.0018969058
Epoch 39: loss: 0.0018968921


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 40: loss: 0.0018948369
Epoch 41: loss: 0.0018962547


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 42: loss: 0.0018947016
Epoch 43: loss: 0.0018949046


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 44: loss: 0.0018906072
Epoch 45: loss: 0.0019010892


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 46: loss: 0.0018929573
Epoch 47: loss: 0.0018980795


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 48: loss: 0.0018937298
Epoch 49: loss: 0.0018904986


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 50: loss: 0.0018953552
Epoch 51: loss: 0.0018937041


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 52: loss: 0.0018965353
Epoch 53: loss: 0.0018881798


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 54: loss: 0.0018956487
Epoch 55: loss: 0.0018909003


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 56: loss: 0.001896214
Epoch 57: loss: 0.0018936355


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 58: loss: 0.001889624
Epoch 59: loss: 0.001886987


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 60: loss: 0.0018843632
Epoch 61: loss: 0.0018803881


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 62: loss: 0.0018805065
Epoch 63: loss: 0.0018829508


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 64: loss: 0.0018837481
Epoch 65: loss: 0.0018784645


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 66: loss: 0.0018778426
Epoch 67: loss: 0.0018761975


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 68: loss: 0.0018766035
Epoch 69: loss: 0.0018787356


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 70: loss: 0.0018769632
Epoch 71: loss: 0.0018788837


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 72: loss: 0.0018797147
Epoch 73: loss: 0.0018731903


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 74: loss: 0.0018689993
Epoch 75: loss: 0.0018671984


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 76: loss: 0.0018683085
Epoch 77: loss: 0.0018628627


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 78: loss: 0.0018674076
Epoch 79: loss: 0.0018705545


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 80: loss: 0.0018604919
Epoch 81: loss: 0.0018704052


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 82: loss: 0.0018631951
Epoch 83: loss: 0.0018622926


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 84: loss: 0.0018582602
Epoch 85: loss: 0.0018535227


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 86: loss: 0.0018598153
Epoch 87: loss: 0.0018553563


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 88: loss: 0.0018641935
Epoch 89: loss: 0.001864215


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 90: loss: 0.0018523722
Epoch 91: loss: 0.0018499993


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 92: loss: 0.0018488502
Epoch 93: loss: 0.0018423517


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 94: loss: 0.0018513656
Epoch 95: loss: 0.0018494421


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 96: loss: 0.0018489873
Epoch 97: loss: 0.0018489037


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 98: loss: 0.0018465866
Epoch 99: loss: 0.0018522037


[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m[1m[36mINFO: [39m[22m[36mThe specified values for size and/or count will result in 444 unused data points
[39m

Epoch 100: loss: 0.0018517583


NeuralDensityEstimator(Session(Ptr{Void} @0x00007f8510d5deb0), <Tensor Group:1 shape=unknown dtype=Any>, <Tensor t1:1 shape=(1, ?) dtype=Float32>, <Tensor gradients/MatMul_grad/MatMul_3:1 shape=(1, ?) dtype=Float32>, <Tensor Sigmoid:1 shape=(1, ?) dtype=Float32>)

In [207]:
# Histogram of the raw sample.
data |> histogram

In [208]:
# Evaluation grid shared by the plots below.
X = -50:0.1:100

plot(X, [pdf(est, x) for x in X])

In [209]:
"""
    empirical_cdf(data, X)

For each point `x` in `X`, return the fraction of `data` that is strictly less than
`x`. This is the left-continuous empirical CDF, ``P(T < x)``.

Returns a `Vector{Float64}`. An empty `data` yields `NaN` at every point (0/0).
"""
function empirical_cdf(data, X)
    n = length(data)
    # `count` tallies matches without allocating a filtered copy per grid point.
    return [count(i -> i < x, data) for x in X] ./ n
end

empirical_cdf (generic function with 1 method)

In [210]:
# Overlay the learned CDF and the sample's empirical CDF on the grid `X`.
plot(X,
     [[cdf(est, x) for x in X], empirical_cdf(data, X)],
     label = ["Estimated" "Empirical"])