In [None]:
using JLD2
using AxisArrays
using Formatting
using Plots
using Colors
using Interact
using Statistics
using MAT

# Collect the data

In [None]:
# Load the raw turn data; its "all_positive_turns" and "all_negative_turns" entries are
# read further down in this block.
turndata = matread(raw"./Data/turn_rates.m")

In [None]:
# Coherence value of each column of the turn matrices (column i ↔ coherences[i], as used
# by the per-coherence means below): 0, -0.1, …, -1 followed by 0.1, …, 1 — 21 values.
coherences = [0:-0.1:-1;0.1:0.1:1];

In [None]:
# Widen the element type so entries can be `missing`, then mark every non-finite value
# (NaN / ±Inf) as missing so that it is skipped when averaging. `t` aliases each array,
# so the originals are modified in place; `(~).(…)` negates the Bool mask element-wise.
positive_turns = convert(Array{Union{Missing, Float64}},turndata["all_positive_turns"]);
negative_turns = convert(Array{Union{Missing, Float64}},turndata["all_negative_turns"]);
for t in [positive_turns, negative_turns]
    t[((~).(isfinite.(t)))] .= missing
end

In [None]:
# Column-wise (per-coherence) means over all rows, ignoring missing values.
# mean_turns_r is computed from positive_turns, mean_turns_l from negative_turns.
mean_turns_r, mean_turns_l = [[mean(skipmissing(tr[:,i_coh])) for i_coh in 1:length(coherences)]
                        for tr in [positive_turns, negative_turns]];

In [None]:
# Permutation that sorts the coherences ascending (-1, …, 0, …, 1), and the sorted grid.
coherence_order = sortperm(coherences);
coherences_sorted = coherences[coherence_order];

Put everything onto one side, exploiting the left/right symmetry.

In [None]:
# Duration used to turn counts into rates (presumably seconds per trial, matching the
# "Turns per second" axis label of the fit plot — TODO confirm).
trial_duration = 12;
# Pool both turn directions: the left-turn mean at sorted coherence c is added to the
# right-turn mean at the mirrored position (coherence_order reversed), which is -c because
# the sorted grid is symmetric about 0. Dividing by 2*trial_duration averages the two
# counts and converts them to a rate.
mean_turn_rate = (mean_turns_l[coherence_order] .+
                  mean_turns_r[coherence_order[end:-1:1]]) ./
                 (2*trial_duration);

# Constraining the integrator model 

We assume that the integrator output is the rate of a Poisson process.

We have to posit a baseline rate (at coherence 0) and allow for inhibition. The baseline rate can be read directly from the data.

In [None]:
# Baseline rate: the pooled rate at the middle of the sorted grid (index 11 of 21),
# i.e. at coherence 0.
r_b = mean_turn_rate[length(coherences)÷2+1]
# Deviation from baseline; this is what the power-law model is fitted to.
turn_rate_deviation = mean_turn_rate .- r_b;

The steady-state equation of the integrator model is: $r_L = L s_l^P + R s_r^P + r_b$

In [None]:
"""
    sim_turn_rate(c_l, c_r, P, coh)

Power-law model of the turn-rate deviation at coherence `coh`: `c_l * coh^P` for
positive coherences, and `-c_r * (-coh)^P` otherwise (including `coh == 0`).
Both sides share the exponent `P` but have their own coefficient.
"""
function sim_turn_rate(c_l, c_r, P, coh)
    # Non-positive coherences use the mirrored magnitude and the negative-side coefficient.
    coh > 0 || return -c_r * ((-coh)^P)
    return c_l * (coh^P)
end

We use the Optim package to find the 3 parameters: a power-law fit with a shared exponent and separate coefficients for the positive and negative sides.

In [None]:
using Optim

In [None]:
"""
    sim_turn_rate_p1(c_l, c_r, coh)

Linear (exponent 1) variant of the turn-rate model: `c_l * coh` for `coh > 0`,
otherwise `-c_r * (-coh)`.
"""
function sim_turn_rate_p1(c_l, c_r, coh)
    return coh > 0 ? c_l * coh : -c_r * (-coh)
end

In [None]:
# Starting guess for the parameters (c_l, c_r, P) of sim_turn_rate.
initial_params = [0.1, 0.1, 1.0];

In [None]:
# Sum of squared residuals between the power-law model and the measured deviations
# from baseline, evaluated at the sorted coherences.
errf = param -> sum(
    ((sim_turn_rate).(param..., coherences_sorted) -
        turn_rate_deviation).^2)
# NOTE(review): `od1` is not used anywhere in the visible notebook; the optimize call
# below is given `errf` directly.
od1 = OnceDifferentiable(errf, initial_params);

In [None]:
# Unconstrained BFGS fit of (c_l, c_r, P); no gradient function is passed here.
res = optimize(errf, initial_params, BFGS());

In [None]:
# Model prediction at the fitted parameters. Data and model are both shifted back up by
# the baseline r_b for plotting.
sim_rate = (sim_turn_rate).(res.minimizer..., coherences_sorted);

scatter(coherences_sorted, turn_rate_deviation .+ r_b, label="data")
plot!(coherences_sorted, sim_rate .+ r_b, label="model", legend=:bottomright)
xlabel!("Coherence")
ylabel!("Turns per second in the stimulus direction")
#savefig("Results/turn_rates_per_coherence.pdf")

In [None]:
# Fitted coefficients: L (positive-coherence side), R (negative side) and the shared
# exponent P.
L, R, P = res.minimizer

The time constants $\tau_I$ and $\tau_C$ will be determined from the transition experiment.

In [None]:
# Integrator weights used by the model functions below: I is the positive-side
# coefficient L, and C is the negated negative-side coefficient R.
I = L
C = -R

# Determine the remaining parameters from the transition experiments

In [None]:
# Measured mean transition responses. Later cells index `mn_responses` with pairs of
# integer coherences (coherence×10, possibly negative), e.g. mn_responses[10,10], so it
# is presumably keyed by (previous, current) — TODO confirm.
@load "Data/mean_transition_responses.jld2" mn_responses

In [None]:
# Current (post-transition) coherence and the set of previous (pre-transition) ones.
current_coherence = 0.6
previous_coherences = [-1.0, -0.6, -0.3, 0.3, 0.6, 1.0];

# (previous, current) pairs for the chosen current coherence.
exp_variants = [(pr, current_coherence) for pr in previous_coherences]

# All 6×3 (previous, current) combinations.
# NOTE(review): all_variants is not used in the visible notebook.
all_variants = [(pr, cc) for pr in previous_coherences for cc in [0.3, 0.6, 1.0]];
            
# Convert a coherence (e.g. 0.6) into the integer index used for mn_responses (e.g. 6).
intcoh = c->round(Int64, c*10)

In [None]:
# One color per entry of previous_coherences (6 entries).
colors = parse.(Colorant, ["#5F4772", "#8A7AAB", "#AFA8CC", "#809B3C", "#6B7C35", "#485727"]);

In [None]:
# Time axis of the measured responses, taken from the first axis of one AxisArray entry.
k = AxisArrays.axes(mn_responses[10,10])[1]
time_exp = k.val

# Measured responses at the current coherence, one trace per previous coherence.
plot()
for (i_prev, previous) in enumerate(previous_coherences)
    plot!(time_exp, mn_responses[intcoh.(previous), intcoh(current_coherence)],
        color=colors[i_prev],
          label="prev. coh $(previous)")
end
plot!(legend=:bottomright)

## Full integrator model

In [None]:
# Simulation grid: a pre-transition and a post-transition phase, each stim_duration long,
# sampled every dt_sim (4000 samples per phase with these values).
stim_duration = 20
# NOTE(review): turn_rate_1 is not referenced in the visible notebook.
turn_rate_1 = 3.0 # number of turns per second for coherence 1
dt_sim = 0.005
n_sim_pre = n_sim_post = round(Int64, stim_duration/dt_sim)
n_sim = n_sim_pre*2
# Time of each of the n_sim samples, starting at 0.
sim_time = (0:n_sim_pre*2-1)*dt_sim

In [None]:
"""
    coherences_from_variant_int(coh_pre, coh_post)

Build the integer coherence time series (coherence × 10, stored as `Int16`) for one
experiment variant. Samples `1:n_sim_pre` carry `coh_pre` and samples
`n_sim_pre+1:n_sim` carry `coh_post`. A positive coherence is written into the first
returned vector; otherwise its magnitude goes into the second.

Relies on the globals `n_sim` and `n_sim_pre`.
"""
function coherences_from_variant_int(coh_pre, coh_post)
    pos_series = zeros(Int16, n_sim)
    neg_series = zeros(Int16, n_sim)
    phases = ((coh_pre, 1:n_sim_pre), (coh_post, n_sim_pre+1:n_sim))
    for (c, idx) in phases
        target, level = c > 0 ? (pos_series, c * 10) : (neg_series, -c * 10)
        target[idx] .= round(Int16, level)
    end
    return pos_series, neg_series
end

In [None]:
include("data_preparation.jl")
include("integration_models_optim.jl")

Exact values for the integrator rates

In [None]:
"""
    turn_angle(I, C, P, τI, τC, variant, angle_per_turn)

Numerically simulate the cumulative turning angle for one `(coh_pre, coh_post)`
`variant`.

The function runs `independent_integrator_model` once per side, with the two stimulus
series swapped. It keeps a window of `round(12/dt_sim)` samples on each side of the
midpoint of the simulation. It then integrates the left-minus-right output, scaled by
`angle_per_turn * dt_sim`.

The result is offset so that its sample at index `round(12/dt_sim)` is zero. Relies on
the global `dt_sim` and on `coherences_from_variant_int`.
"""
function turn_angle(I, C, P, τI, τC, variant, angle_per_turn)
    pos_stim, neg_stim = coherences_from_variant_int(variant...)
    coh_table = collect(0:0.1:1.0)
    rate_l = independent_integrator_model(I, C, P, τI, τC, pos_stim, neg_stim, dt_sim, 0.0, coh_table)
    rate_r = independent_integrator_model(I, C, P, τI, τC, neg_stim, pos_stim, dt_sim, 0.0, coh_table)

    half_window = round(Int64, 12 / dt_sim)
    centre = length(rate_l) ÷ 2
    window = (centre - half_window):(centre + half_window)

    angle = cumsum(rate_l[window] .- rate_r[window]) .* (angle_per_turn * dt_sim)
    # NOTE(review): the zero sits at window index `half_window`, one sample before the
    # centre, whereas turn_angle_exact is zero at index half_window+1 — confirm intended.
    return angle .- angle[half_window]
end

"""
    turn_angle_exact(I, C, P, τI, τC, variant, turn_angle)

Closed-form cumulative turning angle for a `(coh_pre, coh_post)` transition `variant`.
Each coherence `c` enters as the signed power `c^P` (positive) or `-(-c)^P` (otherwise).

Before the transition the angle changes linearly at rate `turn_angle*(I - C)*s_pre`, and
it reaches 0 at sample `n_sim_pre`, which is taken as time 0. After the transition, with
`t = k*dt_sim` for sample `n_sim_pre + k`:

    θ(t) = turn_angle * (s_post*(I - C)*t
           - (s_post - s_pre)*(I*τI*(1 - exp(-t/τI)) - C*τC*(1 - exp(-t/τC))))

Returns the `2n_take + 1` samples centred on sample `n_sim_pre`, where
`n_take = round(Int, 12/dt_sim)`. Relies on the globals `n_sim`, `n_sim_pre` and
`dt_sim`.
"""
function turn_angle_exact(I, C, P, τI, τC, variant, turn_angle)
    coh_pre, coh_post = variant
    angle_total = zeros(n_sim)
    sm, sp = (coh > 0 ? coh^P : -(-coh)^P for coh in [coh_pre, coh_post])
    t_pre = (0:n_sim_pre-1)*dt_sim
    # Post-transition sample n_sim_pre + k lies k steps after the transition, so the
    # closed form is evaluated at t = k*dt_sim (k ≥ 1). Evaluating it at t_pre, i.e.
    # (k-1)*dt_sim, repeated the zero of the transition sample and lagged the whole
    # post-transition curve by one sample relative to the time axis.
    t_post = (1:(n_sim - n_sim_pre))*dt_sim
    angle_total[1:n_sim_pre] = (turn_angle*(I-C)*sm).*(t_pre.-t_pre[end])
    angle_total[(n_sim_pre+1):end] = @. turn_angle*(sp*(I - C) * t_post -
                                     (sp-sm) * (I * τI * (1 - exp(-t_post / τI)) -
                                                C * τC * (1 - exp(-t_post / τC))))
    trial_duration = 12
    n_take = round(Int64, trial_duration/dt_sim)
    return angle_total[n_sim_pre-n_take:n_sim_pre+n_take]
end

In [None]:
# Time axis for the simulated windows: 2n_take+1 samples centred on sample n_sim_pre,
# shifted so that sample n_sim_pre sits at t = 0.
trial_duration = 12
# Use round, as turn_angle/turn_angle_exact do, rather than Int(): trial_duration/dt_sim
# need not be an exact integer in floating point, and Int() throws InexactError then.
n_take = round(Int64, trial_duration/dt_sim)
sim_time_cut = sim_time[n_sim_pre-n_take:n_sim_pre+n_take] .- sim_time[n_sim_pre]

In [None]:
# Interactive exploration: sliders for the two time constants (passed as τI and τC to
# turn_angle_exact) and the angle per turn. Thin solid lines are the measured mean
# responses; dashed lines are the closed-form model.
# NOTE(review): the τS slider starts at 0, where turn_angle_exact divides by a zero time
# constant — confirm the output is usable at that setting.
@manipulate for τS=0:0.05:2.0, τD=0.1:0.05:2.0, turn_amount=10:60
    plot()
    for (i_prev, ev) in enumerate(exp_variants)
        ta = turn_angle_exact(I, C, P, τS, τD, ev, turn_amount)
        plot!(time_exp, mn_responses[intcoh.(ev)],
            c=colors[i_prev], lw=0.5, legend=:none)
        plot!(sim_time_cut, ta, c=colors[i_prev], linestyle=:dash)
    end
    ylims!(-75,75)
    xlabel!("Time [s]")
    ylabel!("Angle turned [°]")
end

A guess at the parameters from the interactive display:

In [None]:
# Starting guess for (τI, τC, turn angle), read off the interactive plot.
initial_params = [1.0, 2.0, 30]

Optimize the 3 parameters ($\tau_I$, $\tau_C$ and the turn angle)

In [None]:
# NOTE(review): current_coherence is reassigned here, but exp_variants below does not
# depend on it.
current_coherence=1.0

In [None]:
# Fit against all 6×3 (previous, current) coherence combinations.
exp_variants = [(pr, cu) for pr in previous_coherences for cu in [0.3, 0.6, 1.0]];

In [None]:
"""
    fit_mistake(params)

Sum of squared differences between the closed-form model (`turn_angle_exact`) and the
measured mean responses, over every variant in the global `exp_variants`.

`params = (τI, τC, turn_amount)`; `I`, `C` and `P` are read from globals. The model's
last sample is dropped before the comparison. Responses are looked up with the key
`(round(Int64, 10*previous), round(Int64, 10*current))`.
"""
function fit_mistake(params)
    τI, τC, turn_amount = params
    sse = 0.0
    for variant in exp_variants
        key = map(c -> round(Int64, c * 10), variant)
        model = turn_angle_exact(I, C, P, τI, τC, variant, turn_amount)[1:end-1]
        residual = model - mn_responses[key]
        sse += sum(residual .^ 2)
    end
    return sse
end

# We need to define the function as being once-differentiable to
# use box-limited optimization
funcopt = OnceDifferentiable(fit_mistake, initial_params);

In [None]:
# Box-constrained fit of (τI, τC, turn angle) with lower bounds (0.01, 0.01, 10) and
# upper bounds (20, 20, 70).
res_turn = optimize(funcopt, 
    [0.01,0.01,10], [20.0, 20.0, 70], initial_params, Fminbox(BFGS()))

In [None]:
# Overwrite the globals τI, τC and turn_amount with the fitted values.
τI, τC, turn_amount = res_turn.minimizer

In [None]:
# Report all fitted parameters (power-law fit and transition fit).
println(format("Optimal paramters are r_b: {:.2f}, I: {:.2f}, C: {:.2f}, P: {:.2f}, τI: {:.2f}, τC: {:.2f}, turn angle: {:.2f}", r_b, I, C, P, τI, τC, turn_amount))

## The error landscape around the two τ values

In [None]:
# Fit error on a grid of offsets around the fitted (τI, τC), with turn_amount held fixed.
# With two iterators the comprehension builds a matrix whose rows follow dτC and whose
# columns follow dτI, matching heatmap's (x = τI, y = τC) layout.
# NOTE(review): offsets down to -0.8 make τ negative if a fitted value is below 0.8.
deltarange = -0.8:0.1:2.5
fit_errors = [
    fit_mistake([τI+dτI, τC+dτC, turn_amount])
    
    for dτC in deltarange, dτI in deltarange
];

In [None]:
heatmap(deltarange .+ τI, deltarange .+ τC, fit_errors, aspect_ratio=1)
#savefig("Results/tauerrorspace_big.png")

### Zoomed

In [None]:
# Same error landscape on a finer grid of ±0.2 around the fitted (τI, τC);
# rows follow dτC, columns follow dτI.
deltarange = -0.2:0.05:0.2
fit_errors = [
    fit_mistake([τI + dτI, τC + dτC, turn_amount])
    
    for dτC in deltarange, dτI in deltarange
];
heatmap(deltarange .+ τI, deltarange .+ τC, fit_errors, aspect_ratio=1)

# Plotting the fit result for coherence 0.6 after the transition

In [None]:
# Variants with current coherence 0.6, one per previous coherence.
red_variants = [(pr, 0.6) for pr in previous_coherences];

In [None]:
# Direction of the vector (C, I) (atan(I, C)), measured from the diagonal π/4, wrapped
# to [0, 2π) and folded into [0, π].
# NOTE(review): ei_θ is not used in the visible notebook.
ei_θ = mod2pi.(atan.(I, C).-(π/4)) |> θ -> (θ > π ? 2*π - θ : θ)

# Single time constant: the average of τI and τC weighted by |I| and |C|.
τw = (τI .* abs.(I) + τC.* abs.(C))./(abs.(I)+ abs.(C));

In [None]:
# Measured mean responses (thin solid) against the numerically simulated turn_angle
# (dashed) for the current-coherence-0.6 variants.
plot(size=(500,500), grid=:none, tick_direction=:out)
for (i_prev, ev) in enumerate(red_variants)
        ta = turn_angle(I, C, P, τI, τC, ev, turn_amount)
        plot!(time_exp, mn_responses[intcoh.(ev)],
            c=colors[i_prev], lw=0.5, legend=:none)
        plot!(sim_time_cut, ta, c=colors[i_prev], linestyle=:dash)
end
ylims!(-100,100)
xlabel!("Time [s]")
ylabel!("Angle turned [°]")
#savefig("Results/Figure_4/poisson.pdf")

# Simulate the model for a run of experiment 26 and plot the turns

## Load a sequence

In [None]:
using DataStructures

In [None]:
# Load the timing parameters `t` (later read as t.dt_sim), the per-plane stimulus record
# `stimdata`, and the stimulus-number → coherence map `stim_map` for experiment 26.
@load "./Data/exp_26_stimsequence.jld2" t stimdata stim_map

In [None]:
"""
    coherence_sequence(stim_number, timing, map)

Reconstruct the upsampled coherence time series from the per-plane stimulus record.
This is a simplified version of the reconstruction that does not require the cell.

# Arguments
- `stim_number`: one stimulus-number array per imaging plane. `stim_number[i]` is
  indexed linearly while its frame count is taken from `size(·, 2)`, so a 1×N row per
  plane is assumed (TODO confirm against the stored data).
- `timing`: object with the fields `n_frames_trial`, `n_upsample`, `dt_sim`,
  `trial_duration` and `dt_stim`.
- `map`: lookup from (rounded) stimulus number to signed coherence. Unknown numbers
  count as coherence 0.

# Returns
`(coh_pos, coh_neg)`: two `UInt8` vectors of length `n_frames_trial*n_upsample`
holding coherence×10. Positive coherences go into `coh_pos`; the magnitudes of the
others go into `coh_neg`.

Samples stay at 0 when they lie beyond the last plane or frame, or when their stimulus
number is ≤ 0.
"""
function coherence_sequence(stim_number, timing, map)
    n_t_ups = timing.n_frames_trial*timing.n_upsample
    coh_pos = zeros(UInt8, n_t_ups)
    coh_neg = zeros(UInt8, n_t_ups)
    n_planes = length(stim_number)

    for i in 1:n_t_ups
        # figure out the time
        t = (i-1)*timing.dt_sim

        # in which plane of the planes where the ROI is
        # is the current t?
        i_plane = floor(Int64, t / timing.trial_duration)+1
        # Beyond the last recorded plane there is no stimulus information. Leave the sample
        # at zero coherence (as for out-of-range frames) instead of raising a BoundsError.
        i_plane <= n_planes || continue

        t_in_plane = t - (i_plane-1)*timing.trial_duration
        idx_plane = round(Int64, t_in_plane/timing.dt_stim)+1

        # idx_plane is ≥ 1 by construction; the upper bound catches rounding past the last
        # stimulus frame of the plane.
        if 1 <= idx_plane <= size(stim_number[i_plane], 2)
            val = stim_number[i_plane][idx_plane]
            if val > 0.
                v1 = round(Int64, val)
                coh = get(map, v1, 0.0)
                if coh > 0
                    coh_pos[i] = round(UInt8, coh*10)
                else
                    coh_neg[i] = round(UInt8, -coh*10)
                end
            end
        end
    end

    return coh_pos, coh_neg
end

In [None]:
# Integer coherence series (coherence×10) for experiment 26: stim_L receives the
# positive-coherence series and stim_R the magnitudes of the non-positive ones.
stim_L, stim_R = coherence_sequence(stimdata, t, stim_map)

In [None]:
include("poisson_model.jl")

In [None]:
# NOTE(review): this coherence table is 0.0:10.0, whereas turn_angle and
# spurt_average_rate pass 0:0.1:1 for the same integer (coherence×10) inputs. Confirm
# which scale the poisson_model.jl version of independent_integrator_model expects.
cohs = collect(0.0:10.0);

"""
    simulate_poisson(stim_L, stim_R, I, C, P, τI, τC)

Run the integrator model for both sides on the integer coherence series `stim_L` and
`stim_R` (time step `t.dt_sim`, coherence table `cohs`). Then add the baseline `r_b` and
pass each rate to the two-argument `simulate_poisson(rates, dt)` method (presumably
defined in poisson_model.jl).

Returns `(int_left, int_right, spikes_l, spikes_r)`; the integrator outputs are
returned without the baseline.
"""
function simulate_poisson(stim_L, stim_R, I, C, P, τI, τC)
    dt = t.dt_sim
    run_side(own, other) = independent_integrator_model(I, C, P, τI, τC, own, other, dt, 0.0, cohs)
    rate_left, rate_right = run_side(stim_L, stim_R), run_side(stim_R, stim_L)
    events_left = simulate_poisson(rate_left .+ r_b, dt)
    events_right = simulate_poisson(rate_right .+ r_b, dt)
    return rate_left, rate_right, events_left, events_right
end

In [None]:
# Simulate experiment 26 using the single weighted time constant τw for both integrators.
int_left, int_right, spikes_l, spikes_r = simulate_poisson(stim_L, stim_R, I, C, P, τw, τw);

In [None]:
using DataFrames

In [None]:
# Integrator outputs of the experiment-26 simulation tabulated against time (from 0).
df_int = DataFrame(t=(0:length(int_left)-1)*t.dt_sim, int_left=int_left, int_right=int_right);

In [None]:
# Colors for the left and right traces.
color_l, color_r = coherence_colors =parse.(Colorant, [
"#B07382",
"#738C89" ] )

In [None]:
# Event rasters drawn as vertical ticks (left at y = 0, right at y = dy), overlaid with
# the integrator traces (right trace offset by dy).
# NOTE(review): the traces use time (1:n)*dt, i.e. starting at dt, while df_int starts
# at 0 — confirm which alignment is intended.
dy = -0.4
scatter(spikes_l, zero(spikes_l), markershape=:vline, markerstrokecolor=color_l)
plot!((1:length(int_left))*t.dt_sim, int_left, color=color_l)

scatter!(spikes_r, fill(dy, length(spikes_r)), markershape=:vline, markerstrokecolor=color_r)
plot!((1:length(int_left))*t.dt_sim, int_right.+ dy, color=color_r, legend=nothing, xlims=(0, length(int_left)*t.dt_sim),
    tick_direction=:out, grid=nothing, size=(1000,400))

# Investigate the fit to the spurts experiment

In [None]:
# NOTE(review): n_fish is not used in the visible notebook.
n_fish = 54;
# Spurt durations; they are divided by dt_sim in spurt_average_rate, so presumably seconds.
const spurt_durations = [1, 2, 3, 4, 6, 8, 10];
# Spurt coherences; also used as keys into turn_rates.
const spurt_coherences = [0.3, 0.6, 1.0];

In [None]:
# Load the measured spurt turn rates used below (turn_rates[coh] is averaged over its
# third dimension). This cell used `@save`, but turn_rates is never defined earlier in
# this notebook, so saving fails with an UndefVarError. Every other data cell in this
# notebook reads its input with `@load`.
# NOTE(review): path kept unchanged ("./data/", "suprt") — confirm it matches the file on
# disk (the other cells read from "./Data/").
@load "./data/suprt_turn_rates.jld2" turn_rates

In [None]:
# Coarser time step for the spurt simulations.
# NOTE: this overwrites the global dt_sim that turn_angle and turn_angle_exact read, while
# n_sim_pre/n_sim keep their values from the 0.005 grid. Re-run the simulation-grid cell
# before reusing those functions.
dt_sim = 0.05

In [None]:
"""
    spurt_average_rate(coherence, spurt_duration, I, C, P, τI, τC)

Mean integrator output on each side during a spurt.

The stimulus is a constant `coherence`, encoded as the integer `round(Int, coherence*10)`.
It is applied to the first input for `round(Int, spurt_duration/dt_sim)` steps, and the
other input stays at 0.

Returns `(mean of the stimulated-side output, mean of the other-side output)`.
Relies on the global `dt_sim`.
"""
function spurt_average_rate(coherence, spurt_duration, I, C, P, τI, τC)
    steps = round(Int, spurt_duration / dt_sim)
    level = round(Int, coherence*10)
    stim_on = fill(level, steps)
    stim_off = zero(stim_on)
    table = collect(0:0.1:1)
    side_mean(a, b) = mean(independent_integrator_model(I, C, P, τI, τC, a, b, dt_sim, 0.0, table))
    return side_mean(stim_on, stim_off), side_mean(stim_off, stim_on)
end

In [None]:
# Per spurt coherence: mean of turn_rates[coh] over its third dimension (presumably
# fish — TODO confirm), keeping rows 1 and 2. error_spurts compares these rows with the
# first and second output of spurt_average_rate.
spurt_exp_turn_rates = [mean(turn_rates[coh],dims=3)[1:2,:,1] for coh in spurt_coherences];
    

In [None]:
"""
    error_spurts(params)

Sum of squared differences between modelled and measured spurt turn rates.

`params = (r_b, I, C, P, τI, τC)`; these locals shadow the globals of the same names.
For every spurt coherence, rows 1 and 2 of the matching `spurt_exp_turn_rates` entry are
compared with the first and second output of `spurt_average_rate`, each plus `r_b`.
"""
function error_spurts(params)
    r_b, I, C, P, τI, τC = params
    total = 0.0
    for (coh, measured) in zip(spurt_coherences, spurt_exp_turn_rates)
        model = spurt_average_rate.(coh, spurt_durations, I, C, P, τI, τC)
        left_residual = first.(model) .+ r_b .- measured[1, :]
        right_residual = last.(model) .+ r_b .- measured[2, :]
        total += sum(left_residual .^ 2)
        total += sum(right_residual .^ 2)
    end
    return total
end

# Start the spurt fit from baseline 0.1 and the parameters fitted to the transition
# experiment, in the (r_b, I, C, P, τI, τC) order used by error_spurts.
initial_params = [0.1, I, C, P, τI, τC]

spurtopt = OnceDifferentiable(error_spurts, initial_params);

In [None]:
# Box-constrained fit of (r_b, I, C, P, τI, τC) with per-parameter lower and upper bounds.
# NOTE(review): τ is bounded above by 10 here (20 in the transition fit) and I, C by ±10.
# Confirm that initial_params lies inside this box.
res_spurt = optimize(spurtopt, 
    [0.0, -10.0, -10.0, 0.01, 0.01, 0.01],
    [1.0, 10.0, 10.0, 4.0, 10.0, 10.0], initial_params, Fminbox(BFGS()));

In [None]:
# Unpack and report the spurt-fit parameters (sp_ prefix keeps the transition-fit globals).
sp_r_b, sp_I, sp_C, sp_P, sp_τI, sp_τC = res_spurt.minimizer;
println(format("Optimal paramters are r_b: {:.2f}, I: {:.2f}, C: {:.2f}, P: {:.2f}, τI: {:.2f}, τC: {:.2f}",
        sp_r_b, sp_I, sp_C, sp_P, sp_τI, sp_τC))

In [None]:
# Measured spurt turn rates (rows 1 and 2, colored) against the spurt-fit model (black)
# for each spurt coherence.
plot()
for (i_coh, coh) in enumerate(spurt_coherences)
    plot!(spurt_durations, spurt_exp_turn_rates[i_coh][1,:])
    plot!(spurt_durations, spurt_exp_turn_rates[i_coh][2,:])
    rates = spurt_average_rate.(coh, spurt_durations, sp_I, sp_C, sp_P, sp_τI, sp_τC )
        plot!(spurt_durations, first.(rates) .+ sp_r_b, color=RGB(0,0,0))
        plot!(spurt_durations, last.(rates).+ sp_r_b, color=RGB(0,0,0))

end
plot!(xlabel="spurt duration [s]", ylabel="turn rate [turns/s]", legend=nothing)

Agreement with the parameters fitted to the transition experiment

In [None]:
# Same comparison, but the model (black) uses the transition-fit parameters
# (I, C, P, τI, τC) and the baseline r_b from the coherence-0 turn rate.
plot()
for coh in spurt_coherences
        plot!(spurt_durations, mean(turn_rates[coh],dims=3)[2,:,1])
        plot!(spurt_durations, mean(turn_rates[coh],dims=3)[1,:,1])
        rates = spurt_average_rate.(coh, spurt_durations, I, C, P, τI, τC )
            plot!(spurt_durations, first.(rates) .+ r_b, color=RGB(0,0,0))
            plot!(spurt_durations, last.(rates).+ r_b, color=RGB(0,0,0))
        
end
plot!(xlabel="spurt duration [s]", ylabel="turn rate [turns/s]")