# Simplified Spatial Pooler

Here I am getting used to Julia by implementing a Spatial Pooler without any fancy features — just plain compute with learning.

## Install packages

In [None]:
# One-off environment setup: add the third-party packages imported in the next cell
# (StatsBase provides `sample`, BenchmarkTools provides `@benchmark`).
using Pkg
Pkg.add("StatsBase")
Pkg.add("Distributions")
Pkg.add("BenchmarkTools")

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`


## Common utils

In [3]:
using Random
using StatsBase
using Distributions
using BenchmarkTools


"Sparse Distributed Space parameters"
struct Sds
    shape::Tuple{Int, Vararg{Int}}
    size::Int
    sparsity::Float64
    active_size::Int

    # Sds(size, sparsity): active_size = round(Int, size * sparsity).
    # The given sparsity is stored as-is, even if rounding makes it differ
    # slightly from active_size / size.
    # Throws ArgumentError if size <= 0 or sparsity is outside [0, 1] (or NaN).
    function Sds(size::Integer, sparsity::AbstractFloat)
        size > 0 || throw(ArgumentError("size must be positive, got $size"))
        0 <= sparsity <= 1 ||
            throw(ArgumentError("sparsity must be in [0, 1], got $sparsity"))
        return new((Int(size),), Int(size), Float64(sparsity), round(Int, size * sparsity))
    end

    # Sds(size, active_size): sparsity = active_size / size.
    # Throws ArgumentError if size <= 0 or active_size is outside 0:size
    # (a zero size would otherwise produce a NaN sparsity).
    function Sds(size::Integer, active_size::Integer)
        size > 0 || throw(ArgumentError("size must be positive, got $size"))
        0 <= active_size <= size ||
            throw(ArgumentError("active_size must be in 0:$size, got $active_size"))
        return new((Int(size),), Int(size), active_size / size, Int(active_size))
    end
end

# Sanity checks for both constructors: sparsity -> active_size and active_size -> sparsity.
@assert Sds(100, .3).active_size == 30
@assert Sds(80, 10).sparsity == 0.125


"""
    rand_sparse([rng,] sds)

Draw a random sparse SDR from `sds`: `sds.active_size` distinct indices taken
from `1:sds.size`, returned in ascending order. Without `rng`, the default
global RNG is used.
"""
function rand_sparse(rng::AbstractRNG, sds::Sds)::Vector{Int}
    indices = sample(rng, 1:sds.size, sds.active_size; replace=false, ordered=true)
    return indices
end
rand_sparse(sds::Sds)::Vector{Int} = rand_sparse(Random.default_rng(), sds)

"""
    sparse_to_dense(sparse_sdr, sds)

Expand a list of active indices into a `Bool` vector of length `sds.size`
that is `true` at every index in `sparse_sdr` and `false` elsewhere.
"""
function sparse_to_dense(sparse_sdr::Vector{Int}, sds::Sds)::Vector{Bool}
    dense = fill(false, sds.size)
    dense[sparse_sdr] .= true
    return dense
end

# Sanity check: a random SDR has `active_size` entries and densifies to exactly that many
# `true`s (the two `rand_sparse` calls draw independent samples; only the counts are compared).
@assert length(rand_sparse(Sds(100, .05))) == sum(sparse_to_dense(rand_sparse(Sds(100, .05)), Sds(100, .05)))

In [9]:
"""
    SpatialPooler

State of a minimal spatial pooler.

# Fields
- `sds_in`, `sds_out`: input and output sparse distributed spaces.
- `rng`: RNG built from the seed by the outer constructor; not read after construction
  in the code shown here.
- `perm_threshold`: permanence at or above which a synapse counts as connected.
- `perm_increment`, `perm_decrement`: permanence steps applied by `stdp_step!`.
- `weights`: permanences, one row per input bit and one column per output cell.
- `synapses`: connectivity mask, kept equal to `weights .>= perm_threshold`.
- `potentials`: per-output-cell scores from the last `compute!` (overlaps as `Float64`).
- `winners`: output cell indices selected by the last `compute!`.
"""
mutable struct SpatialPooler
    sds_in::Sds
    sds_out::Sds
    
    # NOTE(review): abstract field type (boxed); consider parametrizing on the RNG type.
    rng::AbstractRNG 
    perm_threshold::Float64
    perm_increment::Float64
    perm_decrement::Float64
    
    # size (sds_in.size, sds_out.size)
    weights::Matrix{Float64}
    synapses::BitMatrix
    
    potentials::Vector{Float64}
    winners::Vector{Int}
end

"""
    SpatialPooler(sds_in, sds_out, seed, perm_threshold, perm_increment, perm_decrement)

Create a spatial pooler that maps SDRs from `sds_in` to `sds_out`.

Initial permanences (`weights`, size `(sds_in.size, sds_out.size)`) are drawn
uniformly from `[0, 1)` using `MersenneTwister(seed)`. A synapse is connected
when its permanence is `>= perm_threshold`. `potentials` start at zero and
`winners` starts empty.
"""
function SpatialPooler(
        sds_in::Sds, sds_out::Sds, seed::Integer,
        perm_threshold::Real, perm_increment::Real, perm_decrement::Real
    )
    rng = MersenneTwister(seed)

    # one row per input bit, one column per output cell
    weights = rand(rng, sds_in.size, sds_out.size)
    synapses = weights .>= perm_threshold

    potentials = zeros(sds_out.size)
    winners = Int[]  # typed empty vector; `[]` would be Vector{Any}

    return SpatialPooler(
        sds_in, sds_out, rng, perm_threshold, perm_increment, perm_decrement,
        weights, synapses,
        potentials, winners
    )
end

"""
    compute_overlaps(sp, input_sdr)

For every output cell, count how many of the active input bits listed in
`input_sdr` are connected to it in `sp.synapses`.
"""
function compute_overlaps(sp::SpatialPooler, input_sdr::Vector{Int})::Vector{Int}
    active_rows = view(sp.synapses, input_sdr, :)
    column_counts = sum(active_rows; dims=1)  # 1 × n_out matrix
    return vec(column_counts)
end

"""
    compute_winners(sp, potentials)

Return the indices of the `sp.sds_out.active_size` largest entries of
`potentials`, ordered from highest to lowest.
"""
function compute_winners(sp::SpatialPooler, potentials::AbstractVector{<:Real})::Vector{Int}
    # Rank the supplied vector rather than `sp.potentials`, so the argument is honoured.
    return partialsortperm(potentials, 1:sp.sds_out.active_size, rev=true)
end

"""
    update_potentials!(sp, potentials)

Store `potentials` in `sp.potentials` by reference (no copy) and return it.
"""
function update_potentials!(sp::SpatialPooler, potentials::Vector{Float64})
    setproperty!(sp, :potentials, potentials)
    return potentials
end

"""
    stdp_step!(sp, input_sdr, winners)

Apply one learning step to the columns of the `winners` output cells:

1. Decrease every permanence in those columns by `sp.perm_decrement`.
2. Increase the rows of the active inputs `input_sdr` by `sp.perm_increment`
   (net `+increment - decrement`).
3. Clamp the permanences to `[0, 1]`.
4. Recompute `sp.synapses` for those columns against `sp.perm_threshold`.

Returns the recomputed connectivity block (a `BitMatrix`) for the winner columns.
"""
function stdp_step!(sp::SpatialPooler, input_sdr::Vector{Int}, winners::Vector{Int})
    sp.weights[:, winners] .-= sp.perm_decrement
    sp.weights[input_sdr, winners] .+= sp.perm_increment

    # `sp.weights[:, winners]` would be a copy. Clamp through a view so the
    # stored permanences actually stay within [0, 1].
    winner_weights = view(sp.weights, :, winners)
    clamp!(winner_weights, 0.0, 1.0)

    connected = winner_weights .>= sp.perm_threshold
    sp.synapses[:, winners] = connected
    return connected
end

"""
    compute!(sp, input_sdr)

Run one step of the pooler on `input_sdr`:

1. Compute the overlaps and store them as `sp.potentials`.
2. Select the winners and store them in `sp.winners`.
3. Apply one learning step to the winners.

Returns the winner indices.
"""
function compute!(sp::SpatialPooler, input_sdr::Vector{Int})::Vector{Int}
    potentials = convert(Vector{Float64}, compute_overlaps(sp, input_sdr))
    update_potentials!(sp, potentials)

    winners = compute_winners(sp, sp.potentials)
    sp.winners = winners
    stdp_step!(sp, input_sdr, winners)

    return winners
end

# Small demo configuration: 20-bit input space with 5 active bits,
# 30-cell output space with 4 winners.
sds_in, sds_out = Sds(20, 5), Sds(30, 4)
seed = 42
# NOTE(review): this input RNG uses the same seed as the pooler's internal
# MersenneTwister, so inputs come from the same random stream that produced the
# initial weights. Confirm that correlation is intended.
rng = MersenneTwister(seed)
perm_threshold = 0.5
perm_inc, perm_dec = 0.1, 0.01

sp = SpatialPooler(
    sds_in, sds_out, seed,
    perm_threshold, perm_inc, perm_dec
)
input_sdr = rand_sparse(rng, sds_in)
# One compute + learning step; the cell displays the winning output indices.
compute!(sp, input_sdr)

4-element Vector{Int64}:
 14
 25
  1
  3

In [10]:
# Interpolate the non-const globals with `$` so the benchmark times `compute!` itself
# rather than dynamic dispatch on untyped globals. `sp` keeps learning on every evaluation.
@benchmark compute!($sp, rand_sparse($rng, $sds_in))

BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m1.889 μs[22m[39m … [35m100.074 μs[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 96.98%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m2.125 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m2.206 μs[22m[39m ± [32m  2.492 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m2.96% ±  2.56%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m [39m▆[39m▂[39m▃[39m█[39m▃[39m▃[34m█[39m[39m▁[39m▁[39m▄[39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▁[39m▁[39m▁[39m▁[39m▁[39m▂