# Simplified Spatial Pooler

Here I am trying to get used to Julia by implementing a Spatial Pooler without any fancy features: just plain compute with learning.

In [5]:
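struct Sds
    shape::Tuple{Int, Vararg{Int}}   # SDR shape (1-D here)
    size::Int                        # total number of bits
    sparsity::Float64                # fraction of active bits
    active_size::Int                 # number of active bits

    # From a sparsity (Float64): the active size is rounded to the nearest Int.
    Sds(size::Int, sparsity::Float64) = new(
        (size,), size, sparsity,
        round(Int, size * sparsity)
    )

    # From an active size (Int): the sparsity is derived from it.
    Sds(size::Int, active_size::Int) = new(
        (size,), size, active_size / size, active_size
    )
end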

In [6]:
@assert Sds(100, .3).active_size == 30
@assert Sds(80, 10).sparsity == 0.125

In [7]:
sds = Sds(20, 4)

Sds((20,), 20, 0.2, 4)
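The two inner constructors are selected by Julia's multiple dispatch on the type of the second argument: a `Float64` is read as a sparsity, an `Int` as an active size. The sketch below restates the struct so it runs on its own and adds a hypothetical outer method (not part of the original notebook) so that other float types such as `Float32` are accepted too; without it, `Sds(100, 0.3f0)` raises a `MethodError`.

```julia
struct Sds
    shape::Tuple{Int, Vararg{Int}}
    size::Int
    sparsity::Float64
    active_size::Int

    Sds(size::Int, sparsity::Float64) = new((size,), size, sparsity, round(Int, size * sparsity))
    Sds(size::Int, active_size::Int) = new((size,), size, active_size / size, active_size)
end

# Hypothetical convenience method: convert any other float type to Float64
# and forward to the Float64 (sparsity) constructor.
Sds(size::Int, sparsity::AbstractFloat) = Sds(size, Float64(sparsity))

by_sparsity = Sds(20, 0.2)    # Float64 argument => interpreted as sparsity
by_count = Sds(20, 4)         # Int argument => interpreted as active size
by_float32 = Sds(100, 0.3f0)  # Float32 argument => goes through the outer method
```

For `Float64` arguments the inner constructor is more specific, so the outer method only handles the other float types and cannot recurse.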

In [8]:
using Random

seed = 42
rng = MersenneTwister(seed)

MersenneTwister(42)
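Re-seeding is what makes the sampling cells reproducible: two `MersenneTwister` generators created from the same seed yield the same stream, and `copy` snapshots a generator's current state so later draws can be replayed. A minimal sketch with illustrative variable names; note that Julia does not guarantee identical random streams across Julia versions, so only same-version reproducibility is assumed here.

```julia
using Random

# Same seed => same sequence of draws.
rng_a = MersenneTwister(42)
rng_b = MersenneTwister(42)
draws_a = rand(rng_a, 1:100, 5)
draws_b = rand(rng_b, 1:100, 5)

# `copy` snapshots the state of `rng_a`, so its next draws can be replayed.
snapshot = copy(rng_a)
next_a = rand(rng_a, 3)
replayed = rand(snapshot, 3)
```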

In [9]:
using Pkg
Pkg.add("StatsBase")
Pkg.add("Distributions")
Pkg.add("BenchmarkTools")

   Resolving package versions...
  No Changes to `~/.julia/environments/v1.8/Project.toml`
  No Changes to `~/.julia/environments/v1.8/Manifest.toml`
   Resolving package versions...
  No Changes to `~/.julia/environments/v1.8/Project.toml`
  No Changes to `~/.julia/environments/v1.8/Manifest.toml`
   Resolving package versions...
  No Changes to `~/.julia/environments/v1.8/Project.toml`
  No Changes to `~/.julia/environments/v1.8/Manifest.toml`


In [10]:
using StatsBase
using Distributions
using BenchmarkTools

In [11]:
seed = 42
rng = MersenneTwister(seed)

sample(rng, 1:100, 10, replace=false, ordered=true)

10-element Vector{Int64}:
  6
 14
 47
 59
 60
 61
 71
 84
 85
 94

In [12]:
@benchmark sample(rng, 1:100, 10, replace=false, ordered=true)

BenchmarkTools.Trial: 10000 samples with 326 evaluations.
 Range (min … max):  278.500 ns …  1.397 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     283.230 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   285.923 ns ± 35.848 ns  ┊ GC (mean ± σ):  0.37% ± 2.48%

In [13]:
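# Random sparse SDR: `active_size` distinct indices from 1:size, in ascending order.
function rand_sparse(rng::AbstractRNG, sds::Sds)::Vector{Int}
    return sample(rng, 1:sds.size, sds.active_size, replace=false, ordered=true)
end

# Dense form of a sparse SDR: a Bool vector with `true` at the active indices.
function sparse_to_dense(sparse_sdr::Vector{Int}, sds::Sds)::Vector{Bool}
    a = zeros(Bool, sds.size)
    a[sparse_sdr] .= true
    return a
end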

rand_sparse(rng, sds), sparse_to_dense(rand_sparse(rng, sds), sds)

([5, 8, 10, 16], Bool[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0])
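`rand_sparse` relies on `StatsBase.sample`. For comparison, here is a sketch of a stdlib-only variant built on `Random.randperm`, together with a hypothetical `dense_to_sparse` inverse based on `findall`; the helper names and the plain `(n, k)` arguments are assumptions made to keep the block self-contained. In the cell above, the sparse and dense outputs come from two separate draws, which is why their active bits differ; the sketch converts one and the same draw both ways.

```julia
using Random

# Hypothetical stdlib-only variant of `rand_sparse`: `k` distinct indices from
# 1:n, sorted ascending (like `sample(...; replace=false, ordered=true)`).
# `randperm` builds a full permutation of 1:n, so this suits small n only.
function rand_sparse_stdlib(rng::AbstractRNG, n::Int, k::Int)::Vector{Int}
    return sort!(randperm(rng, n)[1:k])
end

# Dense form: a Bool vector with `true` at the active indices.
function to_dense(sparse_sdr::Vector{Int}, n::Int)::Vector{Bool}
    dense = zeros(Bool, n)
    dense[sparse_sdr] .= true
    return dense
end

# Hypothetical inverse of `sparse_to_dense`: indices of the `true` entries.
dense_to_sparse(dense_sdr::AbstractVector{Bool})::Vector{Int} = findall(dense_sdr)

demo_rng = MersenneTwister(7)
demo_sparse = rand_sparse_stdlib(demo_rng, 20, 4)
demo_dense = to_dense(demo_sparse, 20)
```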

In [14]:
sds_in, sds_out = Sds(20, 5), Sds(30, 4)
perm_threshold = 0.5

weights = rand(sds_in.size, sds_out.size)
synapses = weights .>= perm_threshold

synapses

20×30 BitMatrix:
 0  1  1  1  0  0  1  0  0  1  1  1  1  …  1  0  0  1  1  1  0  0  1  0  0  0
 0  0  1  0  0  0  1  1  0  0  0  1  0     1  0  0  1  1  0  1  0  0  0  0  1
 1  1  0  0  1  0  0  1  1  1  1  0  1     0  1  1  1  0  1  1  0  1  1  1  1
 1  0  1  0  0  1  1  1  1  1  0  1  0     1  0  0  0  1  0  0  0  0  1  0  0
 0  0  0  1  0  1  0  0  0  1  1  1  0     1  0  0  0  1  1  1  0  0  1  1  0
 1  1  1  0  1  1  0  0  0  0  1  1  1  …  0  0  1  0  1  1  0  1  0  0  0  1
 1  0  0  0  1  0  1  0  0  0  1  0  0     1  1  1  1  0  0  0  1  1  0  1  1
 1  0  1  0  1  0  0  1  1  0  0  0  1     0  0  1  0  0  0  0  0  0  1  1  0
 1  0  1  1  0  0  0  0  0  0  0  0  0     1  0  1  1  1  0  0  1  1  1  0  0
 1  0  1  0  0  0  1  0  0  1  0  0  0     1  1  1  1  1  0  0  1  1  0  0  1
 0  0  1  0  1  0  1  1  1  0  0  1  0  …  0  1  0  1  1  0  1  0  0  0  1  1
 0  0  0  1  1  0  0  0  0  1  0  1  0     1  0  0  0  0  1  0  1  0  0  0  0
 0  1  1  1  1  1  1  0  1  0  0  0  1     1  0
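The cell above draws `weights` from the global RNG, so the matrix changes on every run even though `rng` was seeded. A sketch, assuming reproducible initialization is wanted: pass a generator to `rand`. With permanences uniform in [0, 1) and a threshold of 0.5, roughly half of the synapses should end up connected. The variable names are illustrative.

```julia
using Random

n_in, n_out = 20, 30
perm_threshold_demo = 0.5

# Same seed => identical permanence matrices.
weights_a = rand(MersenneTwister(42), n_in, n_out)
weights_b = rand(MersenneTwister(42), n_in, n_out)

# A synapse is connected when its permanence reaches the threshold;
# the broadcast comparison yields a BitMatrix.
synapses_a = weights_a .>= perm_threshold_demo

connected_fraction = count(synapses_a) / length(synapses_a)
```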

In [15]:
input_sdr = rand_sparse(rng, sds_in)

synapses[input_sdr, :]

5×30 BitMatrix:
 1  1  0  0  1  0  0  1  1  1  1  0  1  …  0  1  1  1  0  1  1  0  1  1  1  1
 0  0  0  1  0  1  0  0  0  1  1  1  0     1  0  0  0  1  1  1  0  0  1  1  0
 1  0  0  0  1  0  1  0  0  0  1  0  0     1  1  1  1  0  0  0  1  1  0  1  1
 1  1  1  1  1  1  1  0  0  0  1  1  0     0  0  0  1  0  1  1  0  1  1  0  0
 1  1  0  0  0  0  1  1  1  1  1  1  1     1  0  1  0  0  0  0  1  0  0  1  0

In [16]:
overlaps = dropdims(sum(synapses[input_sdr, :], dims=1), dims=1)

30-element Vector{Int64}:
 4
 3
 1
 2
 3
 2
 3
 2
 2
 3
 5
 3
 2
 ⋮
 3
 2
 3
 3
 1
 3
 3
 2
 3
 3
 4
 2
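Summing the rows selected by the active inputs is the same as multiplying the transposed connectivity matrix `S` by the dense 0/1 input vector `x`, i.e. `overlaps = transpose(S) * x`. The sketch checks on random data that this formulation, the notebook's `dropdims` version, and a `vec` variant all agree; the variable names are illustrative.

```julia
using Random
using LinearAlgebra

demo_rng = MersenneTwister(1)
syn = rand(demo_rng, 20, 30) .>= 0.5          # BitMatrix: 20 inputs × 30 columns
active = sort!(randperm(demo_rng, 20)[1:5])    # 5 active inputs (sparse form)

# Notebook formulation: add up the rows of the active inputs.
overlaps_rows = dropdims(sum(syn[active, :], dims=1), dims=1)

# `vec` flattens the 1×30 sum into the same 30-element vector.
overlaps_vec = vec(sum(syn[active, :], dims=1))

# Dense formulation: transpose(S) * x with x the 0/1 input vector.
x_dense = zeros(Int, 20)
x_dense[active] .= 1
overlaps_matvec = transpose(Int.(syn)) * x_dense
```

The sparse row-sum only touches the `active_size` rows, which is why the notebook keeps inputs in sparse form.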

In [17]:
winners = partialsortperm(overlaps, 1:sds_out.active_size, rev=true)

4-element view(::Vector{Int64}, 1:4) with eltype Int64:
 11
  1
 14
 29

In [18]:
overlaps[winners]

4-element Vector{Int64}:
 5
 4
 4
 4
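Selecting the winners is a k-winners-take-all step. Below is a sketch of a hypothetical `kwta` helper that returns the winning indices in ascending order, matching the `ordered=true` convention of `rand_sparse`, plus a conversion to a dense output SDR. When several columns tie at the cut-off (as with the columns at overlap 4 above), which of them win is decided by `partialsortperm` itself, not by any rule of the pooler.

```julia
# Hypothetical helper: indices of the k largest overlaps, sorted ascending.
function kwta(overlaps::AbstractVector{<:Real}, k::Int)::Vector{Int}
    return sort!(collect(partialsortperm(overlaps, 1:k, rev=true)))
end

# Dense output SDR: `true` at the winning columns.
function winners_to_dense(winners::Vector{Int}, n::Int)::Vector{Bool}
    dense = zeros(Bool, n)
    dense[winners] .= true
    return dense
end

demo_winners = kwta([5, 1, 8, 3, 6], 3)   # the values 8, 6 and 5 win
```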

In [25]:
# Overlap of each output column: number of connected synapses from active inputs.
# (`sds_out` is unused here.)
function compute_overlaps(input_sdr::Vector{Int}, synapses::BitMatrix, sds_out)::Vector{Int}
    overlaps = dropdims(sum(synapses[input_sdr, :], dims=1), dims=1)
    return overlaps
end

# k-winners-take-all: indices of the `active_size` columns with the largest overlaps.
function compute_winners(overlaps::Vector{Int}, sds_out)::Vector{Int}
    winners = partialsortperm(overlaps, 1:sds_out.active_size, rev=true)
    return winners
end

# In-place update of the winners' permanences: every synapse of a winner
# loses `perm_dec`, then the synapses from active inputs gain `perm_inc`.
function update_permanence(weights, input_sdr, winners, (perm_inc, perm_dec))
    weights[:, winners] .-= perm_dec
    weights[input_sdr, winners] .+= perm_inc

    # Clamp through a view: `clamp!(weights[:, winners], ...)` would clamp a temporary copy.
    clamp!(view(weights, :, winners), 0.0, 1.0)
    return weights
end

update_permanence (generic function with 1 method)

In [26]:
@benchmark compute_winners(compute_overlaps(rand_sparse(rng, sds_in), synapses, sds_out), sds_out)

BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.083 μs … 168.312 μs  ┊ GC (min … max): 0.00% … 98.86%
 Time  (median):     1.217 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.266 μs ±   2.361 μs  ┊ GC (mean ± σ):  2.63% ±  1.40%

In [33]:
@benchmark begin
    perm_inc, perm_dec = 0.1, 0.01

    input_sdr = rand_sparse(rng, sds_in)
    for _ in 1:10
        overlaps = compute_overlaps(input_sdr, $synapses, sds_out)
        winners = compute_winners(overlaps, sds_out)
        # Pass the current input explicitly; otherwise the update would read
        # the global `input_sdr` from an earlier cell instead of this one.
        update_permanence($weights, input_sdr, winners, (perm_inc, perm_dec))
        $synapses = weights .>= perm_threshold
    end
end

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.791 μs …  2.342 ms  ┊ GC (min … max): 0.00% … 98.24%
 Time  (median):     26.916 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   29.836 μs ± 79.385 μs  ┊ GC (mean ± σ):  9.52% ±  3.53%
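The benchmarked loop reads `rng`, `sds_in`, `sds_out`, `weights` and `perm_threshold` from non-constant globals, which Julia's performance tips advise against. A sketch, assuming the state can be bundled into a hypothetical `SimpleSP` struct: one compute-and-learn step as a function that receives everything it needs as arguments. It applies the same rule as `update_permanence` (the winners' synapses lose `perm_dec`, those from active inputs gain `perm_inc`) and clamps through a view so the stored permanences actually change.

```julia
using Random

# Hypothetical container for the state the notebook keeps in globals.
struct SimpleSP
    weights::Matrix{Float64}   # permanences: inputs × output columns
    perm_threshold::Float64    # connected when permanence >= threshold
    perm_inc::Float64
    perm_dec::Float64
    k::Int                     # number of winning output columns
end

# One compute step; if `learn` is true, the winners' permanences are updated in place.
function sp_step!(sp::SimpleSP, input_sdr::Vector{Int}; learn::Bool=true)::Vector{Int}
    synapses = sp.weights .>= sp.perm_threshold
    overlaps = vec(sum(synapses[input_sdr, :], dims=1))
    winners = collect(partialsortperm(overlaps, 1:sp.k, rev=true))
    if learn
        # Every synapse of a winner loses `perm_dec` ...
        sp.weights[:, winners] .-= sp.perm_dec
        # ... then the synapses from active inputs gain `perm_inc`.
        sp.weights[input_sdr, winners] .+= sp.perm_inc
        # Clamp through a view so the stored permanences are modified.
        clamp!(view(sp.weights, :, winners), 0.0, 1.0)
    end
    return winners
end

# Usage on random data (names are illustrative).
demo_rng = MersenneTwister(42)
sp = SimpleSP(rand(demo_rng, 20, 30), 0.5, 0.1, 0.01, 4)
demo_input = sort!(randperm(demo_rng, 20)[1:5])
weights_before = copy(sp.weights)
step_winners = sp_step!(sp, demo_input)
```

Because `SimpleSP` is immutable but holds a mutable `Matrix`, the permanences can still be updated in place while the field types stay concrete for the compiler.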