In [2]:
using ProgressMeter
using Random
using PyPlot
using Base.Threads

include("utils.jl");


In [3]:
Nh = 20
initial_guess = randn(2*Nh+1)
NTemps           = 100
Nexchanges       = 100
TMax             = 10
λ                = 0.92
NSteps           = 10
StepSize         = 0.1
D                = size(initial_guess, 1)

temperatures     = zeros(NTemps)
temperatures[1]  = TMax
[temperatures[i] = temperatures[i-1]*λ for i in 2:NTemps]

θo = repeat(initial_guess, 1,NTemps)

θo = θo .+ StepSize*randn(D, NTemps)

function compute_energy(θ)
    params = θ
    b = params[end]
    c = params[1:Nh]
    w = params[Nh+1:end-1]
    return energy(0.5f0, b, c, w)
end

params = θo[:,1]
b = params[end]
c = params[1:Nh]
w = params[Nh+1:end-1]
println("size w", size(w))
println("size c", size(c))
println("size b", size(b))
println("energy", compute_energy(params))

size w(20,)
size c(20,)
size b()
energy32.44535931565889


In [33]:
function parallel_tempering(initial_guess, Nh)
    NTemps           = 10
    Nexchanges       = 10
    TMax             = 10
    λ                = 0.92
    NSteps           = 10
    StepSize         = 0.1
    D                = size(initial_guess, 1)

    if D != 2*Nh+1 # for each hidden, there is 1 weight and 1 bias, plus the visible bias
        println("Nh = ", Nh)
        println("D = ", D)
        println("2*Nh+1 = ", 2*Nh+1)
        println("initial_guess = ", initial_guess)
        println("initial_guess size = ", size(initial_guess))
        error("D must be equal to 2*Nh+1")
    end

    temperatures     = zeros(NTemps)
    temperatures[1]  = TMax
    [temperatures[i] = temperatures[i-1]*λ for i in 2:NTemps]

    θo = repeat(initial_guess, 1,NTemps)

    θo = θo .+ StepSize*randn(D, NTemps)
    
    eo = zeros(NTemps)
    en = zeros(NTemps)

    for i in 1:NTemps
        eo[i] = compute_energy(θo[:,i])
    end
    
    EMIN = minimum(eo)
    EMINidx = findmin(eo)
    EMINidx = EMINidx[2]
    θn = zeros(size(θo))

    @showprogress for exch_step in 1:Nexchanges+1
        for step in 1:NSteps
            # Sampling step
            θn = θo .+ StepSize*randn(D, NTemps)
            Threads.@threads for i in 1:NTemps
                en[i] = compute_energy(θn[:,i])
            end

            # accept-reject
            Threads.@threads for i in 1:NTemps
                if en[i] < eo[i]
                    θo[:,i] = θn[:,i]
                    eo[i] = en[i]
                elseif rand() < exp((eo[i]-en[i])/temperatures[i])
                    θo[:,i] = θn[:,i]
                    eo[i] = en[i]
                end
            end

            # find Best
            EMIN = minimum(eo)
            EMINidx = findmin(eo)
            EMINidx = EMINidx[2]
        end

        
        # Exchange step
        for i in 1:NTemps-1 
            # exchange for temperature i and i-1
            ΔE_exchange = (eo[i] - eo[i+1]) * (1/temperatures[i] - 1/temperatures[i+1])
            if ΔE_exchange < 0
                θo[i], θo[i+1] = θo[i+1], θo[i]
                eo[i], eo[i+1] = eo[i+1], eo[i]
            elseif rand() < exp(-ΔE_exchange)
                θo[i], θo[i+1] = θo[i+1], θo[i]
                eo[i], eo[i+1] = eo[i+1], eo[i]
            end
        end
    end

    return θn, eo, en
end

parallel_tempering (generic function with 1 method)

In [34]:
a,b,c = parallel_tempering(initial_guess, Nh)

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:16[39m


([-0.5105623583473824 -0.6449481687500453 … 0.34185283413298007 -0.09080670433451181; 2.724139898315955 0.5928634632500733 … -1.8309938775722534 0.059252898861878206; … ; 1.9560544745403552 1.8410276667478855 … 0.8715473714508725 1.1694456893663054; 2.094056415713504 1.5581002814217482 … 0.6651945694761399 0.4542526527692042], [-17.82720923819844, 42.404877999703444, 33.248077655008956, 32.1003460583891, 40.69700035137006, 28.448159726097245, 24.46622409422132, -1.4265194754143575, 17.958611292068436, 40.16656184526454], [43.095048224123495, -17.82720923819844, 42.404877999703444, 33.248077655008956, 32.1003460583891, 40.69700035137006, 28.448159726097245, 24.46622409422132, 37.52696871973337, 40.16656184526454])