In [1]:
include("../src/Hebbian.jl")

using MLDatasets: MNIST
using Statistics: mean
using Random: seed!
using Flux

# Global parameters
const LATENT_SIZE = 64
const BATCH_SIZE = 128
const DTYPE = Float32


function preprocess_mnist(x)
    x = reshape(x, 28 * 28, size(x, 3))
    binarize(x)::DTYPE = x > 0.5 ? 1 : 0
    @. binarize(x)
end


function reconstruction_error(model::BoltzmannMachine, real::AbstractArray{T, 2}) where T<:Real
    recon = bernoulli_argmax(model, activate(model, real))

    N = 0
    nerror = 0
    for i = 1:size(real, 2)
        for j = 1:model.ambientsize
            N += 1
            if recon[j, i] != real[j, i]
                nerror += 1
            end
        end
    end

    nerror / N
end


function stats(x)
    mean(x), std(x)
end


# Load training data
train_x, _ = MNIST.traindata()
train_x = preprocess_mnist(train_x)
epochs = 4
num_batches = size(train_x, 2) / BATCH_SIZE * epochs |> floor |> Int32
data = Flux.Data.DataLoader(train_x, batchsize=BATCH_SIZE, shuffle=true)
m = size(train_x, 1)

# Construct topology
# 1. connect ambient and latent units
topo = Connection[]
for i = 1:m
    for j = 1:LATENT_SIZE
        push!(topo, Connection(i, m+j))
    end
end
# 2. connect latent and latent units
# for i = 1:LATENT_SIZE
#     for j = 1:LATENT_SIZE
#         push!(topo, Connection(m+i, m+j))
#     end
# end
# 3. connect ambient and ambient units
# for i = 1:m
#     for j = 1:m
#         push!(topo, Connection(i, j))
#     end
# end

model = create_boltzmann(topo, train_x; minval=-1, maxval=1)
opt = Flux.Optimise.ADAM()

function (logger::Logger)(step, real, grads)
    if (step == 1) || (step % logger.logstep == 0)
        m = model.ambientsize
        real_ambient = real[1:m, :]
        real_latent = real[(m+1):end, :]

        @show step
        @show stats(reconstruction_error(model, real))
        @show stats((real_ambient))
        @show stats(mean(real_latent))
        @show stats(mean(grads[1]))
        @show stats(mean(grads[2]))
        @show stats(mean(model.kernel))
        @show stats(mean(model.bias))
    end
end

logger = Logger(model, 1)

Logger(BoltzmannMachine{Float32}(Connection[Connection(1, 785), Connection(1, 786), Connection(1, 787), Connection(1, 788), Connection(1, 789), Connection(1, 790), Connection(1, 791), Connection(1, 792), Connection(1, 793), Connection(1, 794)  …  Connection(848, 775), Connection(848, 776), Connection(848, 777), Connection(848, 778), Connection(848, 779), Connection(848, 780), Connection(848, 781), Connection(848, 782), Connection(848, 783), Connection(848, 784)], 
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿⣿
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⣿⣿

In [2]:
real_ = nothing
for (step, real_ambient) in enumerate(data)
    # Compute real state (particles)
    μ = getlatent(model, real_ambient)
    real_latent = bernoulli_sample(model, μ)
    real_ = cat(real_ambient, real_latent; dims=1)
    @show std(real_latent)
    break
end
# @show mean(outer(real_, real_); dims=3)
# @show grads = hebbian_gradients(model, real_)

std(real_latent) = 1.000051f0


In [3]:
real_ = nothing
for (step, real_ambient) in enumerate(data)
    # Compute real state (particles)
    μ = getlatent(model, real_ambient)
    real_latent = bernoulli_sample(model, μ)
    real_ = cat(real_ambient, real_latent; dims=1)
    @show std(real_latent)
    break
end
# @show mean(outer(real_, real_); dims=3)
# @show grads = hebbian_gradients(model, real_)

std(real_latent) = 1.0000006f0


In [4]:
real_latent = real_[model.ambientsize+1:end, :]
std(real_latent)

1.0000006f0

In [5]:
for epoch = 1:epochs
    println("\n-- Epoch: $(epoch):")
    early_stopped = hebbian_train!(model, data, opt, 0; cb=logger)
end


-- Epoch: 1:
step = 1
stats(reconstruction_error(model, real)) = (0.9850825095663265, NaN)
stats(real_ambient) = (0.13545321f0, 0.3422087f0)
stats(mean(real_latent)) = (-0.01586914f0, NaN32)
stats(mean(grads[1])) = (0.0021446384f0, NaN32)
stats(mean(grads[2])) = (-0.124032654f0, NaN32)
stats(mean(model.kernel)) = (0.0f0, NaN32)
stats(mean(model.bias)) = (-5.176797f0, NaN32)
step = 2
stats(reconstruction_error(model, real)) = (0.9853316326530612, NaN)
stats(real_ambient) = (0.13068001f0, 0.33705166f0)
stats(mean(real_latent)) = (-0.0031738281f0, NaN32)
stats(mean(grads[1])) = (0.0009899528f0, NaN32)
stats(mean(grads[2])) = (-0.12057783f0, NaN32)
stats(mean(model.kernel)) = (-6.522002f-6, NaN32)
stats(mean(model.bias)) = (-5.1762314f0, NaN32)
step = 3
stats(reconstruction_error(model, real)) = (0.9857501594387755, NaN)
stats(real_ambient) = (0.13228436f0, 0.33880144f0)
stats(mean(real_latent)) = (-0.022705078f0, NaN32)
stats(mean(grads[1])) = (0.00318535f0, NaN32)
stats(mean(grads[2])) 