In [16]:
using JLD2, FileIO
# Load the previously saved statistics from the JLD2 file one directory up.
# As the cell output shows, the result is a Dict keyed by statistic name
# whose values are Dicts keyed by a (String, String) condition tuple.
stats = load("../stats.jld2")

Dict{String, Any} with 7 entries:
  "rl90"   => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…
  "rl9"    => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…
  "rl360"  => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…
  "ls30"   => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…
  "roam"   => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…
  "ls"     => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…
  "roam90" => Dict{Tuple{String, String}, Vector{T} where T}(("mod-5", "3DS")=>…

In [12]:
# Sorted names of all statistics present in `stats`.
statnames = sort!([name for name in keys(stats)])
# Sorted list of every condition key that occurs under at least one statistic.
conditions = sort!(collect(mapreduce(keys, ∪, values(stats))))

6-element Vector{Tuple{String, String}}:
 ("N2", "4DS")
 ("N2", "NS")
 ("mod-5", "3DS")
 ("mod-5", "NS")
 ("tph-1", "4DS")
 ("tph-1", "NS")

In [40]:
using Statistics, Unzip
using StatsBase
using Plots

#stage_sp(cond, well, stage) = cond2well2traj[cond][well].speed[stage_frames(well, stage; stagedict)]
#stage_roam(well_i, stage) = roam_for_stage(exwells[well_i]..., trajs[well_i], stage)

"""
    bins(ax, nbins)

Split the index range `ax` into `nbins` consecutive integer ranges.

The `nbins + 1` edges are evenly spaced between `first(ax)` and `last(ax) + 1`
and rounded to integers; bin `k` runs from edge `k` up to (but excluding)
edge `k + 1`.
"""
function bins(ax, nbins)
    lo, hi = first(ax), last(ax) + 1
    cuts = round.(Int, range(lo, hi, length=nbins+1))
    return map(k -> cuts[k]:(cuts[k+1] - 1), 1:nbins)
end
# `true` only for values that are neither `missing` nor non-finite (NaN, ±Inf).
fin(x) = ismissing(x) ? false : isfinite(x)
# Mean of `v` inside each of the `nbins` index bins produced by `bins`.
bin_means(v::AbstractVector, nbins) = map(idx -> mean(view(v, idx)), bins(eachindex(v), nbins))
# Like `bin_means`, but each bin is averaged over its `fin` entries only.
bin_mean_fin(v::AbstractVector, nbins) = map(idx -> mean(filter(fin, view(v, idx))), bins(eachindex(v), nbins))

"""
    bin_means_weights(v, nbins)

Bin `v` into `nbins` index bins (see `bins`) and return a named tuple
`(; means, weights)`: `means[k]` is the mean of the `fin` entries in bin `k`,
and `weights[k]` is the fraction of entries in bin `k` that passed `fin`.
"""
function bin_means_weights(v, nbins)
    ranges = bins(eachindex(v), nbins)
    kept = map(r -> filter(fin, view(v, r)), ranges)
    return (means = map(mean, kept), weights = map(length, kept) ./ map(length, ranges))
end

using PaddedViews

"""
    pad!(rows, val, α=1)

Replace every element of `rows` by a `PaddedView` whose length equals that of
the longest row, filling the added positions with `val`.

`α` is the fraction of the added positions that is placed in front of the
original data (rounded to an integer count); the rest goes behind it.
Returns `rows`.
"""
function pad!(rows, val, α=1)
    target = maximum(length.(rows))
    for i in keys(rows)
        row = rows[i]
        deficit = target - length(row)
        lead = round(Int, deficit*α)
        trail = deficit - lead
        rows[i] = PaddedView(val, row, (1:target,), (lead+1:target-trail,))
    end
    return rows
end

"""
    bin_means_multiple(conds, stat, stage, pad=false, α=0.0; nbins=250)

Collect the rows stored in the global `stats[stat][cond][stage]` for every
`cond` in `conds`, sort them by length within each condition, and bin each row
into `nbins` bins with `bin_means_weights`.

Returns `nothing` when there are no rows; otherwise the tuple
`(M, stat_rows, cond_ends)` where `M` has one row of binned means per entry of
`stat_rows` and `cond_ends` holds the cumulative row count after each condition.
When `pad` is true the rows are first padded to equal length with `missing`
via `pad!(stat_rows, missing, α)`.
"""
function bin_means_multiple(conds, stat, stage, pad=false, α=0.0; nbins=250)
    per_cond = [stats[stat][c][stage] for c in conds]
    cond_ends = cumsum(length.(per_cond))
    # `init=[]` keeps the container untyped so `pad!` can store padded views in it
    stat_rows = reduce(vcat, sort.(per_cond; by=length); init=[])
    if isempty(stat_rows)
        return nothing
    end
    if pad
        pad!(stat_rows, missing, α)
    end
    binned = bin_means_weights.(stat_rows, nbins)
    means, _ = unzip(binned)
    return reduce(hcat, means)', stat_rows, cond_ends
end

bin_means_multiple (generic function with 3 methods)

# Activity Heatmaps

In [72]:
using Interact
# Interactive heatmap of binned means: one row per entry returned by
# `bin_means_multiple`, with the rows of the selected conditions stacked.
# Widgets choose the conditions, statistic, stage, number of bins, padding
# (on/off and front fraction α) and the displayed colour range.
@manipulate throttle=0.1 for conds = togglebuttons(conditions, multiple=true, value=[("N2", "NS"), ("N2", "4DS")]), 
                stat in statnames,
                stage = 1:5, nbins=50:5:150, 
                pad=true, α in slider(0:0.01:1, value=0),
                crange = rangepicker(0:0.01:1)
    # sort in place so the stacking order does not depend on the toggle order
    sort!(conds)
    # stat_rows = reduce(vcat, 
    #                     sort(stats[stat][cond][stage]; by=length) for cond in conds; init=[])
    #sort!(stat_rows; by=length)
    M = bin_means_multiple(conds, stat, stage, pad, α; nbins)
    if M === nothing
        # no rows for this selection: show an empty plot
        plot()
    else
        M, _, cond_ends = M
        # colour limits: `crange` is applied as a pair of fractions of the finite
        # data range (presumably rangepicker yields a two-element value — TODO confirm)
        Mmin, Mmax = extrema(filter(isfinite,M))
        cmin, cmax = Mmin .+ crange .* (Mmax-Mmin)
        heatmap(M, legend=false, cbar=true, clim=(cmin,cmax))
        #cond_ends = cumsum(length(stats[stat][cond][stage]) for cond in conds)
        # white separators after the last row of each condition; the two calls
        # below modify the heatmap just created (the current plot)
        hline!(cond_ends .+ 0.5, c="white")
        # one tick label per condition, centred on its block of rows
        yticks!((midpoints([0;cond_ends]), join.(conds, " ")))
        #size(M)
    end

end

In [73]:
# Copy of `a` in which every NaN has been replaced by `missing`.
function n2m(a)
    return replace(a, NaN => missing)
end

# The entries of `x` that pass `fin`.
function finites(x)
    return filter(fin, x)
end

"""
    spstats(stat; nbins=250)

Bin every element of `stat` into `nbins` bins with `bin_means_weights` and
summarise the result across elements.

Returns the tuple `(μ, Σ, μ_d, bin_means, weights)`:
- `bin_means`: one row per element of `stat`, one column per bin; NaN bin
  means are replaced by `missing`.
- `weights`: same layout as `bin_means`; the weights from `bin_means_weights`.
- `μ`: per-bin mean over the entries of `bin_means` that pass `fin`.
- `μ_d`: `bin_means` (with `missing` mapped back to NaN) minus `μ`, multiplied
  by `weights` using a product that returns 0 whenever either factor is 0,
  then divided per bin by `√(sum of weights - 1)`.
- `Σ`: `μ_d' * μ_d`.

NOTE(review): the square root receives a negative argument when the weights of
a bin sum to less than 1 — confirm this cannot happen for the intended inputs.
"""
function spstats(stat; nbins=250)
    # multiplication with strong 0: 0⋆NaN==0
    a ⊛ b = iszero(a) || iszero(b) ? zero(promote_type(typeof(a),typeof(b))) : a*b

    means, weights = unzip(bin_means_weights.(stat, nbins))
    bin_means = mapreduce(n2m, hcat, means)'
    μ = mean.(filter.(fin, eachcol(bin_means)))
    weights = reduce(hcat, weights)'

    # TODO right way to deal with missing values?
    # missing -> NaN so the arithmetic below stays in floating point; the strong
    # zero product then cancels those NaNs wherever the weight is 0
    μ_c = coalesce.(bin_means,NaN)
    μ_d = (μ_c .- μ') .⊛ weights ./ .√(sum(weights; dims=1).-1)
    Σ = μ_d' * μ_d

    return μ, Σ, μ_d, bin_means, weights
end

# Interactive view of one statistic for one condition and stage: a heatmap of
# the binned means on top and, below it, per-bin median with quartile ribbon,
# per-bin mean, and 5%/95% quantile lines.
@manipulate for stat in statnames,
                condition in conditions,
                stage in 1:5,
                nbins = 1:500,
                zero_mean = false
    μ, Σ, μ_d, bin_means, weights = spstats(stats[stat][condition][stage]; nbins)

    # optionally subtract the per-bin mean from every row (in place)
    zero_mean && (bin_means .-= μ')
    # quartiles per bin, computed over the entries that pass `fin`
    q_low, med, q_high = unzip([quantile(filter(fin,c), (0.25,0.5,0.75)) for c in eachcol(bin_means)])
    d_plt = plot(med, ribbon = (med.-q_low, q_high.-med) , label="quartiles")
    # the following `plot!` calls add to `d_plt`, which is the current plot here
    plot!([mean(finites(c)) for c in eachcol(bin_means)], label="mean")
    plot!(reduce(hcat, quantile(filter(fin,c), [0.05,0.95]) for c in eachcol(bin_means))', 
                label=["5%" "95%"], c="black", alpha=0.2)
    plot(heatmap(bin_means), d_plt, layout=(2,1), size=(600,600))
end

# Distribution distances and correlations

In [74]:
using Distances
using StatsBase
using StatsBase: normalize

#cum_dist(h1, h2) = mean(abs(x-y) for (x,y) in zip(cumsum(normalize(h1,1)), cumsum(normalize(h2,1))))
#jsnorm(h1, h2) = js_divergence(normalize(h1,1), normalize(h2,1))
# Jensen–Shannon divergence of `x` and `y` divided by log(2).
function js_bits(x, y)
    return js_divergence(x, y) / log(2)
end

# Squared Hellinger distance.
function hellinger²(x, y)
    return hellinger(x, y)^2
end

# Global bin counts; the widget cell below declares its own variables with
# these names.
n_time_bins = 150
n_stat_bins = 50

# Interactive comparison of time bins for one statistic/condition/stage.
# Left column: matrix of distances `d` between the per-time-bin histograms of
# the binned means. Right column: Spearman correlation matrix of `μ_d`.
# Each matrix has the per-bin summary curves overlaid and a histogram of its
# entries underneath.
@manipulate for d in [hellinger, hellinger², js_bits], stat in statnames, 
                condition in conditions, stage in 1:5,
                n_time_bins = 1:300,
                n_stat_bins = 1:200

    μ, Σ, μ_d, bin_means, weights = spstats(stats[stat][condition][stage]; nbins=n_time_bins)
    #s = 0:0.025:1
    # histogram edges: `n_stat_bins` equal-width bins spanning the finite data range
    s = range(extrema(finites(bin_means))...; length=n_stat_bins+1)
    #s = -1.5:0.01:1

    # sanity check that the edges cover all finite values
    # NOTE(review): the last edge equals the data maximum; whether that value is
    # counted depends on how `fit(Histogram, ...)` closes its bins — confirm
    @assert minimum(s) <= minimum(finites(bin_means))
    @assert maximum(s) >= maximum(finites(bin_means))

    # W: one column per time bin holding that bin's probability-normalised histogram
    W = reduce(hcat, normalize(fit(Histogram, finites(c), s), mode=:probability).weights for c in eachcol(bin_means))
    # C: pairwise distance between the histograms of every two time bins
    C = [d(c1, c2) for c1 in eachcol(W), c2 in eachcol(W)]

    # affine map (scale a, offset b) that places the summary curves inside the
    # heatmap axes: the data minimum is mapped to y = number of time bins
    a = 100
    b = size(bin_means,2) - a*minimum(finites(bin_means))
    # Overlay median/quartiles, mean and 5%/95% quantiles on the current plot.
    function binned_mean_inset_plot!()
        q_low, med, q_high = unzip([quantile(filter(fin,c), (0.25,0.5,0.75)) for c in eachcol(bin_means)])
        plot!(b.+a.*med, ribbon = ((med.-q_low) .* a, (q_high.-med) .* a) , label="quartiles")
        plot!(b.+a.*μ, label="mean", c=1)
        plot!(reduce(hcat, b.+a.*quantile(finites(c), [0.05,0.95]) for c in eachcol(bin_means))', 
                    label=["5%" "95%"], c="black", alpha=0.2)
    end

    dist_d_hm = heatmap(C, aspect_ratio=1, clims=(0,1), 
            title="$(join(condition," ")) (n=$(size(bin_means,1))) stage $stage\n $stat, $d ($n_stat_bins bins)",
            #legend=:outerbottom,
            legend=false,
            )
    # must directly follow the heatmap: it draws onto the current plot
    binned_mean_inset_plot!()

    dist_d_hist = histogram(vec(C), legend=false, yaxis=false, yticks=false, xlims=(0,1))

    S = corspearman(μ_d)
    cor_hm = heatmap(S, c=:diverging, 
                    title="spearman",
                    legend=false,
                    #legend=:outerbottom
                    clims=(-1,1), aspect_ratio=1, 
                    )
    # same overlay, now drawn onto the correlation heatmap
    binned_mean_inset_plot!()

    cor_hist = histogram(vec(S), legend=false, yaxis=false, yticks=false, xlims=(-1,1))
    # mark zero correlation on the histogram just created
    vline!([0])

    plot(dist_d_hm, cor_hm, dist_d_hist, cor_hist,
        layout=grid(2,2, heights=(0.8,0.2)),
        size=(600,600))

    #savefig("$outdir/$stat-$stage-$n_time_bins-$d-$n_stat_bins-$(join(cond,"-")).png"); current()
end


In [28]:
using MultivariateStats

# `true` when `r` contains no `missing` entry.
isfullrow(r) = all(!ismissing,r)

"""
    fullrows(a)

Return the rows of `a` that contain no `missing`, in their original order, as
a dense matrix with element type `nonmissingtype(eltype(a))`.

The result is always two-dimensional: it has zero rows when no row is
complete and a single row when exactly one is.
"""
function fullrows(a::AbstractMatrix{T}) where T
    keep = Bool[isfullrow(r) for r in eachrow(a)]
    # logical row selection avoids the repeated `hcat` and works for 0 or 1 hits
    return convert(Matrix{nonmissingtype(T)}, a[keep, :])
end

"""
    sort_rows_by(m, v)

Return the rows of `m` reordered so that the corresponding entries of `v` are
ascending. Rows with equal keys keep their relative order. Only as many keys
as `m` has rows are considered.
"""
function sort_rows_by(m, v)
    rowkeys = collect(Iterators.take(v, size(m, 1)))
    # `sortperm` is stable, so ties keep their original relative order
    return m[sortperm(rowkeys), :]
end


# Interactive view of a single principal component of the binned means.
# Left: the PCA mean with a ±√(variance)·component band (top) and the selected
# component itself (bottom). Right: heatmap of the complete rows, sorted by
# their product with the selected component.
@manipulate for stat in statnames, condition in conditions, stage in 1:5, nbins=1:300,
                pc = slider(1:20, value=1)
    μ, Σ, μ_d, mean_mat, weights = spstats(stats[stat][condition][stage]; nbins)

    # keep only rows without missing bins
    mean_mat_f = fullrows(mean_mat)

    # `fit(PCA, ...)` receives the transposed matrix (bins × rows)
    pca_t = fit(PCA, mean_mat_f')
    Pt = projection(pca_t)
    # flip the sign of all components so that the first one sums to a non-negative value
    Pt .*= sign(sum(Pt[:,1]))
    
    #m_plt = plot(mean(pca_t), lw=3, label="mean")
    m_plt = plot(mean(pca_t), lw=3, label="\$\\mu\$")
    # variance of the selected component
    v = principalvars(pca_t)[pc]
    plot!(m_plt, mean(pca_t), ribbon = √v*Pt[:,pc], label="", c=2, la=0, fa=0.2)
    #plot!(m_plt, mean(pca_t) .+ √v .* Pt[:,pc], label="±σ_$pc⋅PC_$pc", c=2, lw=1)
    plot!(m_plt, mean(pca_t) .+ √v .* Pt[:,pc], label="\$\\pm \\sigma_{$pc} \\mathrm{PC}_{$pc}\$", c=2, lw=1, legendfontsize=12)
    pca_plt = plot(Pt[:,pc], fillrange=0, label="", c=2)
    # sort key is `mean_mat_f * Pt[:,pc]`, i.e. computed without subtracting the PCA mean
    sorted_hm = heatmap(sort_rows_by(mean_mat_f, mean_mat_f * Pt[:,pc]), cbar=false)
    plot(plot(m_plt, pca_plt, layout=(2,1)), sorted_hm, layout=(1,2), size=(900,600))
end


In [75]:
# `true` when `r` contains no `missing` entry.
isfullrow(r) = all(!ismissing,r)

"""
    fullrows(a)

Return the rows of `a` that contain no `missing`, in their original order, as
a dense matrix with element type `nonmissingtype(eltype(a))`.

The result is always two-dimensional: it has zero rows when no row is
complete and a single row when exactly one is.
"""
function fullrows(a::AbstractMatrix{T}) where T
    keep = Bool[isfullrow(r) for r in eachrow(a)]
    # logical row selection avoids the repeated `hcat` and works for 0 or 1 hits
    return convert(Matrix{nonmissingtype(T)}, a[keep, :])
end

# Interactive PCA of the complete rows of the binned means, fitted either on
# the matrix as is (`by == "time bins"`) or on its transpose (`by == "worms"`).
# Top: heatmap of the data with horizontal bar plots of the column vectors
# beside it. Bottom: line plot of the row vectors and the explained variance
# of the first `n_pcs` components.
@manipulate for stat in statnames, condition in conditions, stage in 1:5, nbins=1:300,
                by = togglebuttons(["worms", "time bins"]), show_mean = false, n_pcs = slider(1:20, value=4, label="PCs")
    μ, Σ, μ_d, mean_mat, weights = spstats(stats[stat][condition][stage]; nbins)
    bycols = by == "time bins"

    mean_mat_f = fullrows(mean_mat)
    # orientation handed to `fit(PCA, ...)`
    x = bycols ? mean_mat_f : mean_mat_f'
    pca = fit(PCA, x)
    # note: `μ` is rebound here to the PCA mean, replacing the value from `spstats`
    μ, P = mean(pca), projection(pca)
    # flip the sign of all components so that the first one sums to a non-negative value
    P .*= sign(sum(P[:,1]))
    pcs = P[:,1:n_pcs]
    pc_labels = ["PC $i" for i=1:n_pcs]
    coeff_labels = ["c $i" for i=1:n_pcs]
    # coordinates of the mean-centred data along the selected components
    coeffs = (x .- μ)' * pcs
    # which vectors go beside the heatmap (col_*) and which into the line plot (row_*)
    col_vecs   = bycols ? show_mean ? hcat(μ, pcs) : pcs : coeffs
    col_labels = bycols ? show_mean ? vcat(["μ"], pc_labels) : pc_labels : coeff_labels
    # mean is at the end so the colors of PCAs are not affected and match between subplots
    row_vecs   = bycols ? coeffs : show_mean ? hcat(pcs, μ) : pcs
    # NOTE(review): `nrow` is not used below
    ncol, nrow = size(col_vecs,2), size(row_vecs,2)

    # fraction of total variance per component (bars) and cumulative (line)
    pvars_plt = bar(principalvars(pca)[1:n_pcs] ./ tvar(pca), legend=false, 
                     group = 1:n_pcs)
    plot!(cumsum(principalvars(pca)[1:n_pcs] ./ tvar(pca)), marker=true)
    plot(
        # NOTE(review): `condition` is a tuple here, whereas other cells pass
        # `join(condition, " ")` as the title — confirm the rendering is as intended
        plot(heatmap(mean_mat_f, title=condition),
            #bar(mean(pca), orientation=:horizontal, axis=false, ticks=false, title="μ"),
            #(bar(P[:,i], orientation=:horizontal, axis=false, ticks=false, title="PC$i") for i=1:n_pcs)...,
            (bar(v, orientation=:horizontal, axis=false, ticks=false; title) 
                    for (title,v) in zip(col_labels, eachcol(col_vecs))
            )...,
            legend=false, link=:y,
            layout=grid(1,ncol+1,widths=normalize([10;fill(1,ncol)],1))),
        plot(
            #plot(mean_mat_f' * P[:,1:n_pcs], label=(1:n_pcs)', legend=false), 
            plot(row_vecs, label=(1:n_pcs)', legend=false), 
            pvars_plt, 
            layout=(1,2)),
        layout=(2,1), size=(600,600))
end


In [122]:
using ColorSchemes
# For every element of `x`, its position within `unique(x)`.
function self_indexed(x)
    distinct = unique(x)
    return indexin(x, distinct)
end

# Heatmap of `m` with the `:diverging` colour map and colour limits symmetric
# about zero (± the largest absolute entry); extra keywords go to `heatmap`.
function symheatmap(m; kwargs...)
    bound = maximum(abs.(m))
    return heatmap(m, clims=(-1,1).*bound, c=:diverging; kwargs...)
end
#@manipulate for stat in statnames, condition in conditions, stage in 1:5, nbins=1:300
# Interactive projection of several conditions onto the PCA of one reference
# condition. The PCA is fitted on the complete rows of `pca_condition`; the
# rows of all selected conditions are then transformed with that fit.
# Top: data, condition labels and PC coordinates (raw and divided by σ), all
# sorted by the coordinate on component `sortby`. Bottom: per-condition
# jittered scatter and step histogram of that coordinate.
@manipulate for stat in statnames, pca_condition in conditions, 
        conds = togglebuttons(conditions, multiple=true, value=[("N2", "NS"), ("N2", "4DS")]),
        stage in 1:5, nbins=150, n_pcs in 1:20, sortby in slider(1:20, value=1),
        palettename in dropdown(sort!([k for (k,v) in colorschemes if v.category ∈ ("colorbrewer2", "tableau", "seaborn")]),
                                value=:tableau_10)
    pal = palette(palettename)
    
    μ, Σ, μ_d, mean_mat, weights = spstats(stats[stat][pca_condition][stage]; nbins)
    mean_mat_f = fullrows(mean_mat)
    pca = fit(PCA, mean_mat_f'; pratio=1)

    # the reference condition always comes first and is never duplicated
    conds = [pca_condition] ∪ conds
    M, stat_rows, cond_ends = bin_means_multiple(conds, stat, stage; nbins)
    # one condition label per row of M
    row_labels = inverse_rle(conds, diff([0; cond_ends]))
    #@show row_labels
    # keep only rows whose bins are all finite
    i_fin = [all(fin,r) for r in eachrow(M)]
    M_f, row_labels_f = M[i_fin,:], row_labels[i_fin]
    #heatmap(M_f)
    M_t = MultivariateStats.transform(pca, M_f')
    σ = .√principalvars(pca)
    M_t_σ = M_t ./ σ
    
    # use numerical groups to ensure they are ordered correctly
    row_labels_f_i = indexin(row_labels_f, conds)
    
    # scatter: x is the coordinate on component `sortby`, y is a per-condition
    # offset plus a small random jitter
    hists = scatter(M_t[sortby,:], -0.1row_labels_f_i .+ 0.01 .* randn.(), 
            group=row_labels_f_i, alpha=0.2, label="", color_palette = pal, title="PC $sortby")
    stephist!(M_t[sortby,:], nbins=15, lw=3, normed=true, 
            group=row_labels_f_i, label=permutedims(join.(conds, " ")), 
            c=(1:length(conds))', color_palette = pal,
            legend=:outerright)
    # rows of `m` reordered by ascending coordinate on component `pc_i`
    pc_sort_rows(m, pc_i) = permutedims(mapreduce( first, hcat, sort(collect(zip(eachrow(m), M_t[pc_i,:])); by=last) ))
    label_hm = heatmap(pc_sort_rows(row_labels_f_i, sortby), yticks=false, xticks=false,
            c = cgrad(pal[1:length(conds)]))
    plot(plot(heatmap(pc_sort_rows(M_f, sortby)), 
         label_hm,
         symheatmap(pc_sort_rows(M_t[1:n_pcs,:]', sortby), yticks=false), 
         symheatmap(pc_sort_rows(M_t_σ[1:n_pcs,:]', sortby), yticks=false), 
         layout=grid(1,4, widths=normalize([5,1,5,5],1)), 
         cbar=false, link=:y
         ),
         hists, layout=(2,1), size=(900,600), )
end

In [121]:
# Interactive PCA fitted on the pooled complete rows of all selected
# conditions. Shows the stacked data as a heatmap followed by scatter plots of
# every pair of the first `n_pcs` component coordinates, grouped by condition.
@manipulate for stat in statnames, conds = togglebuttons(conditions, multiple=true, value=[("N2", "NS"), ("N2", "4DS")]),
                stage in 1:5, nbins=1:300, n_pcs in 1:8
    sort!(conds)
    # only the binned-mean matrix (4th return value of `spstats`) is needed per condition
    _, _, _, mean_mats, _ = unzip( spstats(stats[stat][condition][stage]; nbins) for condition in conds )
    mean_mats_f = [fullrows(m) for m in mean_mats]
    mean_mat_f = reduce(vcat, mean_mats_f)
    # cumulative row counts, and for every pooled row the index of its condition
    src_cond_ends = cumsum(size.(mean_mats_f,1))
    src_cond   = reduce(vcat, [fill(i,len) for (i,len) in enumerate(size.(mean_mats_f,1))]) 
    
    #size(mean_mat_f), length(src_cond)
    pca = fit(PCA, mean_mat_f'; pratio=1)
    mm_t = MultivariateStats.transform(pca, mean_mat_f')
    # NOTE(review): `σ` and `mm_t_σ` are not used below
    σ = .√principalvars(pca)
    mm_t_σ = mm_t ./ σ
    
    hm = heatmap(mean_mat_f, xticks=false)
    # separators and tick labels are added to `hm`, the current plot
    hline!(src_cond_ends .+ 1/2, label="", c="white")
    yticks!((midpoints([0;src_cond_ends]), join.(conds, " ")))

    plot( hm,
          (scatter(mm_t[i,:], mm_t[j,:], xlabel="PC $i", ylabel="PC $j", alpha=0.3, legend=false, group=src_cond) 
                    for i=1:n_pcs for j=i+1:n_pcs)..., 
             size=(800,600), cbar=false )
end

In [118]:
using LinearAlgebra
# One minus the Spearman rank correlation of `x` and `y`.
function spearman_dist(x, y)
    return 1 - corspearman(x, y)
end

# One minus the Kendall rank correlation of `x` and `y`.
function kendall_dist(x, y)
    return 1 - corkendall(x, y)
end

# Divide every row of `x` in place by its 1-norm and return `x`.
function normrows!(x)
    # materialise the norms first so they are not recomputed from
    # partially updated rows inside the fused in-place broadcast
    rownorms = copy(norm.(eachrow(x), 1))
    return x ./= rownorms
end

# Interactive comparison of two conditions for one statistic and stage.
# Left: heatmap of the complete rows of both conditions stacked (cond1 on
# top of the matrix, then cond2), with a white line after each block.
# Right: step histograms of all pairwise row distances `d` within each
# condition. With `normalized` set, every row is first divided by its 1-norm.
@manipulate for stat in statnames, cond1 in conditions, cond2 in conditions,
                stage in 1:5, nbins=1:300, d in (euclidean, cityblock, corr_dist, spearman_dist, kendall_dist),
                normalized=false
    # only the binned-mean matrix (4th return value of `spstats`) is needed
    _, _, _, mean_mat1, _ = spstats(stats[stat][cond1][stage]; nbins)
    _, _, _, mean_mat2, _ = spstats(stats[stat][cond2][stage]; nbins)

    mean_mat_f1 = fullrows(mean_mat1)
    mean_mat_f2 = fullrows(mean_mat2)
    if normalized
        normrows!(mean_mat_f1)
        normrows!(mean_mat_f2)
    end

    cond_ends = cumsum([size(mean_mat_f1,1), size(mean_mat_f2,1)])
    hm = heatmap(vcat(mean_mat_f1, mean_mat_f2), xticks=false, cbar=false)
    # separators and tick labels are added to `hm`, the current plot
    hline!(cond_ends .+ 1/2, label="", c="white")
    yticks!((midpoints([0;cond_ends]), join.([cond1, cond2], " ")))

    # `pairwise` yields the full distance matrix, so the histograms include
    # the self-distances and both orderings of every pair
    step_plt = stephist(vec(pairwise(d, eachrow(mean_mat_f1))), normed=true, label=join(cond1," "))
    stephist!(vec(pairwise(d, eachrow(mean_mat_f2))), normed=true, label=join(cond2," "))
    plot(hm, step_plt)
end


##### 