In [None]:
using Pkg
# Resolve the project file relative to this notebook's directory (the same anchor the
# `include` calls use) so the intended environment is activated even when the current
# working directory is not the notebook folder.
Pkg.activate(joinpath(@__DIR__, "..", "Project.toml"))
# Install/precompile whatever the activated environment's manifest requires.
Pkg.instantiate()

In [None]:
using Dates
using Plots
using JLD2
using LsqFit, StatsBase
using QuadGK
using Parameters
using Distributions
using SpecialFunctions
using Random
using SparseArrays
using MultivariateStats

include(joinpath(@__DIR__,"..","src","Parameter_Setting.jl"))
include(joinpath(@__DIR__,"..","src","Dynamics.jl"))
include(joinpath(@__DIR__,"..","src","reversal_learning.jl"))
include(joinpath(@__DIR__,"..","src","Babadi_Formula.jl"))
include(joinpath(@__DIR__,"..","src","Plot.jl"))

"""
    cent(X)

Return `X` with the mean taken along its first dimension subtracted, i.e. every
column of a matrix is shifted to have zero mean.
"""
function cent(X)
    column_means = mean(X, dims=1)
    return X .- column_means
end
"""
    corr(X, N)

Return `C ./ sqrt.(d * d')` with `C = X' * X / N` and `d = diag(C)`, i.e. the scaled
Gram matrix of `X` normalised so that its diagonal equals one.

`X` is not centred here, so the result is a correlation matrix only if the columns of
`X` already have zero mean (e.g. after `cent`).

NOTE(review): `diag` lives in LinearAlgebra, which this cell does not load explicitly;
presumably one of the included source files brings it into scope. Confirm.
"""
function corr(X, N)
    C = (X' * X) / N   # formed once; the one-liner recomputed this product three times
    d = diag(C)
    return C ./ sqrt.(d * d')
end
"""
    cov(X, N)

Return the Gram matrix `X' * X` divided by `N`. No centring and no normalisation
of the diagonal is applied.
"""
function cov(X, N)
    gram = X' * X
    return gram / N
end
# Six-level categorical colour gradients: blues for `palette_J`, oranges for `palette_h`
# (presumably used for the J- and h-related curves in the figures; confirm).
palette_J, palette_h = let nlevels = 6
    cgrad(:Blues, nlevels, categorical=true), cgrad(:Oranges, nlevels, categorical=true)
end
"""
    tmp_correlation_ϕ(y, ρ, f)

Inner integral used by `correlation_ϕ`. For fixed `y` it integrates, over `x` in
`[-5, 5]`,

    (ϕ(x - T) - f) * (ϕ(ρ*x + sqrt(1-ρ^2)*y - T) - f) * exp(-(x^2 + y^2)/2)

with `T = Tf(f)`, and multiplies the result by `1/2π`.

`ϕ` and `Tf` are looked up as globals when the function is called.
NOTE(review): `Tf(f)` is presumably the threshold giving coding level `f`; confirm.
"""
function tmp_correlation_ϕ(y, ρ, f)
    # Both quantities are constant with respect to the integration variable, so they
    # are evaluated once instead of at every quadrature node.
    T = Tf(f)
    c = sqrt(1 - ρ^2)
    integrand(x) = (ϕ(x - T) - f) * (ϕ(ρ*x + c*y - T) - f) * exp(-(x^2 + y^2)/2)
    return (1/2π) * quadgk(integrand, -5, 5)[1]
end

"""
    correlation_ϕ(ρ, f)

Integrate `tmp_correlation_ϕ(y, ρ, f)` over `y` in `[-5, 5]` and divide by `f*(1-f)`.

Together with the inner integral this is a two-dimensional quadrature over the square
`[-5, 5]^2` with a standard bivariate Gaussian weight, normalised by `f*(1-f)`.
"""
function correlation_ϕ(ρ, f)
    outer = quadgk(y -> tmp_correlation_ϕ(y, ρ, f), -5, 5)[1]
    return outer / (f * (1 - f))
end
"""
    angle_correct(θ)

Return `θ` unchanged when it is strictly positive, otherwise `θ + 2π`.
Note that `θ == 0` is shifted to `2π`.
"""
function angle_correct(θ)
    if θ > 0
        return θ
    end
    return θ + 2π
end

angle_correct (generic function with 1 method)

### reversal learning

### Fig. 5E

In [None]:
# Fig. 5E sweep: reversal learning with a single pattern (P = 1).
# Recorded wall-clock time of the full sweep: 1h55m.
# Mutates the global parameter object `p` and calls `reversal_learning`; both are
# presumably provided by the source files included at the top. Confirm.
Random.seed!(1)
p.t_end = 10.
p.r = 0.5
p.dt = 0.1

p.Ns = 100
p.K = 100
p.P = 1
# One label per pattern.
L = zeros(p.P)
L[1] = +1
# Pattern matrix (Ns × P): the first half of the entries of column 1 is sqrt(2), the
# rest is 0. No second column exists because P = 1.
S = zeros(p.Ns,p.P)
S[:,1] = [ones(Int(p.Ns/2)); zeros(Int(p.Ns/2))] * sqrt(2)
# Heaviside step nonlinearity returning 1 for x > 0 and 0 otherwise.
ϕ(x) = x > 0 ? 1 : 0
p.f = 0.1
p.model = :dense_normal
N_simu = 3000; #500

p.γW = 10^(-1.)
p.αW = 1.0


# Sweep grids. γ is applied to both p.γh and p.γJ; σ is applied to p.σh or p.σJ.
γlist = [10^x for x in -1.:0.25:1.]
σlist = [2^x for x in -1.:1.:1.]
Nclist = [10^3] #[10^x for x in 2:3]
λlist = [2^0.] #[2^x for x in -2:2:2]

t1 = now()
println(t1)
# First index: 1 → σ applied to p.σh (p.σJ = 0), 2 → σ applied to p.σJ (p.σh = 0).
# Element type is left as Any because the return type of `reversal_learning` is not
# visible from this notebook.
gL_λσ = Array{Any}(undef,2,length(γlist),length(σlist),length(Nclist),length(λlist))
# Loop order (γ outermost, λ innermost) fixes the order in which random numbers are
# consumed after the seed above, so it must not be changed.
for i_γ in eachindex(γlist), i_σ in eachindex(σlist), i_Nc in eachindex(Nclist), i_λ in eachindex(λlist)
    p.γh = γlist[i_γ]; p.γJ = γlist[i_γ]; p.Nc = Nclist[i_Nc]; p.λ = λlist[i_λ]

    p.σh = σlist[i_σ]; p.σJ = 0.;
    gL_λσ[1,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)

    p.σh = 0.; p.σJ = σlist[i_σ];
    gL_λσ[2,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)
end
t2 = now()
println(t2)
println(t2-t1)


# Save next to the notebook (same anchor as the `include`s), create the output folder if
# it is missing, and strip ':' from the timestamp because colons are not valid in
# Windows file names. A failed save here would otherwise only surface after the
# hours-long sweep. Note that `p` is stored in its state after the last iteration.
let outdir = joinpath(@__DIR__, "..", "data"), stamp = replace(string(now()), ':' => '-')
    mkpath(outdir)
    jldsave(joinpath(outdir, "reversal_learning_1odor_$(stamp).jld2");gL_λσ,p,S,L,γlist,σlist,Nclist,λlist)
end

2025-07-29T21:47:41.523
2025-07-29T23:42:56.189
6914666 milliseconds


### Fig. 5F

In [None]:
# Fig. 5F sweep: reversal learning with two patterns (P = 2) carrying opposite labels.
# Recorded wall-clock time of the full sweep: 3h17m.
# Mutates the global parameter object `p` and calls `reversal_learning`; both are
# presumably provided by the source files included at the top. Confirm.
Random.seed!(1)
p.t_end = 10.
p.r = 0.5
# Set explicitly so this cell does not silently depend on the Fig. 5E cell (which sets
# the same value) having been executed earlier in the session.
p.dt = 0.1

p.Ns = 100
p.K = 100
p.P = 2
# One label per pattern: +1 for pattern 1, -1 for pattern 2.
L = zeros(p.P)
L[1] = +1; L[2] = -1
# Pattern matrix (Ns × P): column 1 is sqrt(2) on the first half of the entries,
# column 2 on the second half, so the two columns do not overlap.
S = zeros(p.Ns,p.P)
S[:,1] = [ones(Int(p.Ns/2)); zeros(Int(p.Ns/2))] * sqrt(2)
S[:,2] = [zeros(Int(p.Ns/2)); ones(Int(p.Ns/2))] * sqrt(2)
# Heaviside step nonlinearity returning 1 for x > 0 and 0 otherwise.
ϕ(x) = x > 0 ? 1 : 0
p.f = 0.1
p.model = :dense_normal
N_simu = 3000; #500

p.γW = 10^(-1.)
p.αW = 1.0

# Sweep grids. γ is applied to both p.γh and p.γJ; σ is applied to p.σh or p.σJ.
γlist = [10^x for x in -1.:0.25:1.]
σlist = [2^x for x in -1.:1.:1.]
Nclist = [10^3]#[10^x for x in 2:3]
λlist = [2^0.] #[2^x for x in -2:2:2]

t1 = now()
println(t1)
# First index: 1 → σ applied to p.σh (p.σJ = 0), 2 → σ applied to p.σJ (p.σh = 0).
# Element type is left as Any because the return type of `reversal_learning` is not
# visible from this notebook.
gL_λσ = Array{Any}(undef,2,length(γlist),length(σlist),length(Nclist),length(λlist))
# Loop order (γ outermost, λ innermost) fixes the order in which random numbers are
# consumed after the seed above, so it must not be changed.
for i_γ in eachindex(γlist), i_σ in eachindex(σlist), i_Nc in eachindex(Nclist), i_λ in eachindex(λlist)
    p.γh = γlist[i_γ]; p.γJ = γlist[i_γ]; p.Nc = Nclist[i_Nc]; p.λ = λlist[i_λ]

    p.σh = σlist[i_σ]; p.σJ = 0.;
    gL_λσ[1,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)

    p.σh = 0.; p.σJ = σlist[i_σ];
    gL_λσ[2,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)
end
t2 = now()
println(t2)
println(t2-t1)


# Save next to the notebook (same anchor as the `include`s), create the output folder if
# it is missing, and strip ':' from the timestamp because colons are not valid in
# Windows file names. A failed save here would otherwise only surface after the
# hours-long sweep. Note that `p` is stored in its state after the last iteration.
let outdir = joinpath(@__DIR__, "..", "data"), stamp = replace(string(now()), ':' => '-')
    mkpath(outdir)
    jldsave(joinpath(outdir, "reversal_learning_2odor_$(stamp).jld2");gL_λσ,p,S,L,γlist,σlist,Nclist,λlist)
end

2025-07-29T23:43:17.290
2025-07-30T03:00:29.491
11832201 milliseconds


### Fig. S5B

In [None]:
# Fig. S5B sweep: two-pattern reversal learning (P = 2) at fixed σ = 2, varying p.Nc
# over 10^2 and 10^3.
# Recorded wall-clock time of the full sweep: 1h8m.
# Mutates the global parameter object `p` and calls `reversal_learning`; both are
# presumably provided by the source files included at the top. Confirm.
Random.seed!(1)
p.t_end = 10.
p.r = 0.5
# Set explicitly so this cell does not silently depend on the Fig. 5E cell (which sets
# the same value) having been executed earlier in the session.
p.dt = 0.1

p.Ns = 100
p.K = 100
p.P = 2
# One label per pattern: +1 for pattern 1, -1 for pattern 2.
L = zeros(p.P)
L[1] = +1; L[2] = -1
# Pattern matrix (Ns × P): column 1 is sqrt(2) on the first half of the entries,
# column 2 on the second half, so the two columns do not overlap.
S = zeros(p.Ns,p.P)
S[:,1] = [ones(Int(p.Ns/2)); zeros(Int(p.Ns/2))] * sqrt(2)
S[:,2] = [zeros(Int(p.Ns/2)); ones(Int(p.Ns/2))] * sqrt(2)
# Heaviside step nonlinearity returning 1 for x > 0 and 0 otherwise.
ϕ(x) = x > 0 ? 1 : 0
p.f = 0.1
p.model = :dense_normal
N_simu = 3000; #500

p.γW = 10^(-1.)
p.αW = 1.0

# Sweep grids. γ is applied to both p.γh and p.γJ; σ is applied to p.σh or p.σJ.
γlist = [10^x for x in -1.:0.25:1.]
σlist = [2.]#[2^x for x in -1.:1.:2.]
Nclist = [10^x for x in 2:3]
λlist = [2^0.] #[2^x for x in -2:2:2]

t1 = now()
println(t1)
# First index: 1 → σ applied to p.σh (p.σJ = 0), 2 → σ applied to p.σJ (p.σh = 0).
# Element type is left as Any because the return type of `reversal_learning` is not
# visible from this notebook.
gL_λσ = Array{Any}(undef,2,length(γlist),length(σlist),length(Nclist),length(λlist))
# Loop order (γ outermost, λ innermost) fixes the order in which random numbers are
# consumed after the seed above, so it must not be changed.
for i_γ in eachindex(γlist), i_σ in eachindex(σlist), i_Nc in eachindex(Nclist), i_λ in eachindex(λlist)
    p.γh = γlist[i_γ]; p.γJ = γlist[i_γ]; p.Nc = Nclist[i_Nc]; p.λ = λlist[i_λ]

    p.σh = σlist[i_σ]; p.σJ = 0.;
    gL_λσ[1,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)

    p.σh = 0.; p.σJ = σlist[i_σ];
    gL_λσ[2,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)
end
t2 = now()
println(t2)
println(t2-t1)


# Save next to the notebook (same anchor as the `include`s), create the output folder if
# it is missing, and strip ':' from the timestamp because colons are not valid in
# Windows file names. A failed save here would otherwise only surface after the
# hours-long sweep. Note that `p` is stored in its state after the last iteration.
let outdir = joinpath(@__DIR__, "..", "data"), stamp = replace(string(now()), ':' => '-')
    mkpath(outdir)
    jldsave(joinpath(outdir, "reversal_learning_2odor_Nc_$(stamp).jld2");gL_λσ,p,S,L,γlist,σlist,Nclist,λlist)
end

2025-07-30T03:00:49.251
2025-07-30T04:09:11.196
4101945 milliseconds


### Fig. S5A

In [None]:
# Fig. S5A sweep: two-pattern reversal learning (P = 2) at fixed σ = 2 and Nc = 10^3,
# varying p.λ over 2^-2, 2^0 and 2^2.
# Recorded wall-clock time of the full sweep: 5h6m.
# Mutates the global parameter object `p` and calls `reversal_learning`; both are
# presumably provided by the source files included at the top. Confirm.
Random.seed!(1)
p.t_end = 10.
p.r = 0.5
# Set explicitly so this cell does not silently depend on the Fig. 5E cell (which sets
# the same value) having been executed earlier in the session.
p.dt = 0.1

p.Ns = 100
p.K = 100
p.P = 2
# One label per pattern: +1 for pattern 1, -1 for pattern 2.
L = zeros(p.P)
L[1] = +1; L[2] = -1
# Pattern matrix (Ns × P): column 1 is sqrt(2) on the first half of the entries,
# column 2 on the second half, so the two columns do not overlap.
S = zeros(p.Ns,p.P)
S[:,1] = [ones(Int(p.Ns/2)); zeros(Int(p.Ns/2))] * sqrt(2)
S[:,2] = [zeros(Int(p.Ns/2)); ones(Int(p.Ns/2))] * sqrt(2)
# Heaviside step nonlinearity returning 1 for x > 0 and 0 otherwise.
ϕ(x) = x > 0 ? 1 : 0
p.f = 0.1
p.model = :dense_normal
N_simu = 3000; #500

p.γW = 10^(-1.)
p.αW = 1.0

# Sweep grids. γ is applied to both p.γh and p.γJ; σ is applied to p.σh or p.σJ.
γlist = [10^x for x in -1.:0.25:1.]
σlist = [2.]#[2^x for x in -1.:1.:2.]
Nclist = [10^3]#[10^x for x in 2:3]
λlist = [2^x for x in -2:2:2.]

t1 = now()
println(t1)
# First index: 1 → σ applied to p.σh (p.σJ = 0), 2 → σ applied to p.σJ (p.σh = 0).
# Element type is left as Any because the return type of `reversal_learning` is not
# visible from this notebook.
gL_λσ = Array{Any}(undef,2,length(γlist),length(σlist),length(Nclist),length(λlist))
# Loop order (γ outermost, λ innermost) fixes the order in which random numbers are
# consumed after the seed above, so it must not be changed.
for i_γ in eachindex(γlist), i_σ in eachindex(σlist), i_Nc in eachindex(Nclist), i_λ in eachindex(λlist)
    p.γh = γlist[i_γ]; p.γJ = γlist[i_γ]; p.Nc = Nclist[i_Nc]; p.λ = λlist[i_λ]

    p.σh = σlist[i_σ]; p.σJ = 0.;
    gL_λσ[1,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)

    p.σh = 0.; p.σJ = σlist[i_σ];
    gL_λσ[2,i_γ,i_σ,i_Nc,i_λ] = reversal_learning(p,N_simu)
end
t2 = now()
println(t2)
println(t2-t1)


# Save next to the notebook (same anchor as the `include`s), create the output folder if
# it is missing, and strip ':' from the timestamp because colons are not valid in
# Windows file names. A failed save here would otherwise only surface after the
# hours-long sweep. Note that `p` is stored in its state after the last iteration.
let outdir = joinpath(@__DIR__, "..", "data"), stamp = replace(string(now()), ':' => '-')
    mkpath(outdir)
    jldsave(joinpath(outdir, "reversal_learning_2odor_λ_$(stamp).jld2");gL_λσ,p,S,L,γlist,σlist,Nclist,λlist)
end

2025-07-30T04:09:31.129
2025-07-30T09:15:23.584
18352455 milliseconds
