In [1]:
include("../src/PhasorNetworks.jl")
using .PhasorNetworks, Plots

[32m[1mPrecompiling[22m[39m SciMLSensitivity
[32m  ✓ [39m[90mTracker[39m
[32m  ✓ [39m[90mArrayInterface → ArrayInterfaceTrackerExt[39m
[32m  ✓ [39m[90mTracker → TrackerPDMatsExt[39m
[32m  ✓ [39m[90mZygote → ZygoteTrackerExt[39m
[32m  ✓ [39m[90mRecursiveArrayTools → RecursiveArrayToolsTrackerExt[39m
[32m  ✓ [39m[90mDiffEqBase → DiffEqBaseTrackerExt[39m
[32m  ✓ [39m[90mSimpleNonlinearSolve → SimpleNonlinearSolveTrackerExt[39m
[32m  ✓ [39m[90mNonlinearSolve[39m
[32m  ✓ [39m[90mNonlinearSolve → NonlinearSolveZygoteExt[39m
[32m  ✓ [39m[90mNonlinearSolve → NonlinearSolveNLsolveExt[39m
[32m  ✓ [39m[90mNonlinearSolve → NonlinearSolveBandedMatricesExt[39m
[32m  ✓ [39m[90mDiffEqCallbacks[39m
[32m  ✓ [39m[90mOrdinaryDiffEq[39m
[32m  ✓ [39m[90mStochasticDiffEq[39m
[32m  ✓ [39mSciMLSensitivity
  15 dependencies successfully precompiled in 102 seconds. 260 already precompiled.
[32m[1mPrecompiling[22m[39m ComponentArraysTrackerExt
[32

In [2]:
using Lux, MLUtils, MLDatasets, OneHotArrays, Statistics, Test
using Random: Xoshiro, AbstractRNG
using Base: @kwdef
using Zygote: withgradient
using LuxDeviceUtils: cpu_device, gpu_device
using Optimisers, ComponentArrays
using Statistics: mean
using LinearAlgebra: diag
using Distributions: Normal
using DifferentialEquations: Heun, Tsit5

In [3]:
@info "Running similarity test..."
"""
    check_phase(matrix; tol = epsilon, offset = round(Int, size(matrix, 1) / 2))

Check a similarity matrix for the expected in-phase / anti-phase structure.

Return a tuple `(v1, v2)` of `Bool`s:
- `v1` is `true` when every entry on the main diagonal is greater than `1 - tol`;
- `v2` is `true` when every entry on superdiagonal `offset` is less than `-1 + tol`.

An empty diagonal counts as passing (`all` over an empty collection is `true`).

# Keywords
- `tol`: similarity tolerance. Defaults to the notebook-level global `epsilon`, read at
  call time, so existing `check_phase(sims)` calls behave as before.
- `offset`: index of the superdiagonal holding the anti-phase pairs. Defaults to half the
  row count (rounded to nearest, ties to even), which matches the previous
  `round(n_x / 2)` when rows are indexed by the `n_x` phases — assumed here; confirm if
  `n_x != n_y`.
"""
function check_phase(matrix::AbstractMatrix; tol::Real = epsilon,
                     offset::Integer = round(Int, size(matrix, 1) / 2))
    in_phase = diag(matrix)
    anti_phase = diag(matrix, offset)

    # `all` short-circuits and, unlike `reduce(*, ...)`, is defined on empty input.
    v1 = all(>(1.0 - tol), in_phase)
    v2 = all(<(-1.0 + tol), anti_phase)
    return v1, v2
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mRunning similarity test...


check_phase (generic function with 1 method)

In [4]:
using DifferentialEquations

In [5]:
# Sizes of the phase sweep: n_x / n_y phases along dim 2, n_vsa along dim 3.
n_x = 101
n_y = 101
n_vsa = 1
# Number of cycles in each repeated spike train.
repeats = 10
# Tolerance read by `check_phase`.
epsilon = 0.025

solver_args = Dict(:adaptive => true, :abstol => 1e-6, :reltol => 1e-6)
spk_args = SpikingArgs(t_window = 0.01,
                    threshold = 0.001,
                    solver = Tsit5(),
                    solver_args = solver_args)

# Simulation interval from 0 to `repeats`, and a time grid sampled every `spk_args.dt`.
tspan = (0.0, repeats * 1.0)
tbase = collect(tspan[1]:spk_args.dt:tspan[2])

1001-element Vector{Float64}:
  0.0
  0.01
  0.02
  0.03
  0.04
  0.05
  0.06
  0.07
  0.08
  0.09
  0.1
  0.11
  0.12
  ⋮
  9.89
  9.9
  9.91
  9.92
  9.93
  9.94
  9.95
  9.96
  9.97
  9.98
  9.99
 10.0

In [6]:
# Evenly spaced phases on [-1, 1], materialised as dense 1 × n × n_vsa arrays.
phase_x = collect(reshape(range(-1.0, 1.0; length = n_x), 1, n_x, n_vsa))
phase_y = collect(reshape(range(-1.0, 1.0; length = n_y), 1, n_y, n_vsa))

1×101×1 Array{Float64, 3}:
[:, :, 1] =
 -1.0  -0.98  -0.96  -0.94  -0.92  -0.9  …  0.9  0.92  0.94  0.96  0.98  1.0

In [7]:
# Outer similarity of every x phase against every y phase; keep the [1, 1, :, :] slice.
sims = similarity_outer(phase_x, phase_y; dims = 2, reduce_dim = 1)[1, 1, :, :]
v1, v2 = check_phase(sims)

(true, true)

In [10]:
import .PhasorNetworks: phase_to_train

In [16]:
"""
    phase_to_train(phases; spk_args, repeats = 1, offset = 0.0, period = 1.0)

Convert an array of phases into a `SpikeTrain` with one spike per element per cycle.

Spike times come from `phase_to_time(phases; spk_args, offset)`. Spike `i` of the
train is attached to the `CartesianIndex` of the corresponding element of `phases`.
When `repeats > 1`, the whole set of spikes is repeated `repeats` times, with copy `k`
(0-based) shifted later by `k * period`.

# Keywords
- `spk_args::SpikingArgs`: spiking configuration forwarded to `phase_to_time`.
- `repeats::Int`: number of cycles to emit; must be at least 1.
- `offset::Real`: time offset forwarded to `phase_to_time` and stored in the train.
- `period::Real`: time between repeated cycles. Defaults to `1.0`, the value previously
  hard-coded. NOTE(review): presumably this should match the cycle length implied by
  `spk_args` — confirm.

# Throws
- `ArgumentError` if `repeats < 1`.
"""
function phase_to_train(phases::AbstractArray; spk_args::SpikingArgs, repeats::Int = 1,
                        offset::Real = 0.0, period::Real = 1.0)
    repeats >= 1 || throw(ArgumentError("repeats must be at least 1, got $repeats"))

    shape = size(phases)
    indices = vec(collect(CartesianIndices(shape)))
    times = vec(phase_to_time(phases; spk_args = spk_args, offset = offset))

    if repeats > 1
        n_t = length(times)
        # Every spike of cycle k is delayed by k * period.
        cycle_shifts = repeat((0:repeats-1) .* period; inner = n_t)
        times = repeat(times, repeats) .+ cycle_shifts
        indices = repeat(indices, repeats)
    end

    return SpikeTrain(indices, times, shape, offset)
end

phase_to_train (generic function with 1 method)

In [17]:
# Encode both phase sweeps as spike trains spanning `repeats` cycles.
st_x = phase_to_train(phase_x; spk_args, repeats)
st_y = phase_to_train(phase_y; spk_args, repeats)

Spike Train: (1, 101, 1) with 1010 spikes.

In [18]:
sims_2 = similarity_outer(st_x, st_y; dims = 2, reduce_dim = 3, tspan, spk_args) |> stack;

In [19]:
# Inspect the stacked spiking similarity array computed in the previous cell.
sims_2

1×1×665×101×101 Array{Float64, 5}:
[:, :, 1, 1, 1] =
 -1.0

[:, :, 2, 1, 1] =
 1.0

[:, :, 3, 1, 1] =
 1.0

;;; … 

[:, :, 663, 1, 1] =
 1.0

[:, :, 664, 1, 1] =
 0.999999761581428

[:, :, 665, 1, 1] =
 1.0

[:, :, 1, 2, 1] =
 -1.0

[:, :, 2, 2, 1] =
 0.999999761581428

[:, :, 3, 2, 1] =
 1.0

;;; … 

[:, :, 663, 2, 1] =
 0.998044968816231

[:, :, 664, 2, 1] =
 0.998044492212216

[:, :, 665, 2, 1] =
 0.998044968816231

[:, :, 1, 3, 1] =
 -1.0

[:, :, 2, 3, 1] =
 1.0

[:, :, 3, 3, 1] =
 1.0

;;; … 

[:, :, 663, 3, 1] =
 0.9922151174897635

[:, :, 664, 3, 1] =
 0.9922148795356591

[:, :, 665, 3, 1] =
 0.9922148795356591

;;;; … 

[:, :, 1, 99, 1] =
 -1.0

[:, :, 2, 99, 1] =
 -0.5000000000000002

[:, :, 3, 99, 1] =
 -0.5000000000000002

;;; … 

[:, :, 663, 99, 1] =
 0.9909032545536149

[:, :, 664, 99, 1] =
 0.9915155938183489

[:, :, 665, 99, 1] =
 0.9916005294263144

[:, :, 1, 100, 1] =
 -1.0

[:, :, 2, 100, 1] =
 -0.5000000000000002

[:, :, 3, 100, 1] =
 -0.5000000000000002

;;; … 

[:,