In [71]:
using Profile
using ProfileView
using StatProfilerHTML
using BenchmarkTools


using LinearAlgebra
using FFTW
using Statistics
using JLD2
using NPZ
using ProgressMeter
# 

function sinusoidal(a, f, sr, t, theta=0, DC=0)
    delta_i = 1 / sr
    f2pi = f * 2 * π
    nu = [DC + (a * sin(f2pi * i * delta_i + theta)) for i in 0:(t-1)]
    return nu
end

function tfr_estimate_size(n_epochs, n_channels, n_taps, n_freqs, n_times)
    element_size = 16 # bytes for ComplexF64
    total_elements = n_epochs * n_channels * n_taps * n_freqs * n_times
    total_bytes = total_elements * element_size
    total_gb = total_bytes / (1024^3) # Convert bytes to GB
    return total_gb
end

function weights_estimate_size(n_taps, n_freqs, n_times)
    element_size = 8 # bytes for Float64
    total_elements = n_taps * n_freqs * n_times
    total_bytes = total_elements * element_size
    total_gb = total_bytes / (1024^3) # Convert bytes to GB
    return total_gb
end

function Ws_estimate_size(n_taps, freqs, sfreq, n_cycles)
    element_size = 16 # bytes for ComplexF64
    total_elements = 0
    for k = 1:n_freqs
        f = freqs[k]
        t_win = n_cycles / f
        len_t = Int(ceil(t_win * sfreq))
        total_elements += n_taps * len_t
    end
    total_bytes = total_elements * element_size
    total_gb = total_bytes / (1024^3) # Convert bytes to GB
    return total_gb
end

function fft_estimate_size(a, b, c)
    element_size = 16 # bytes for ComplexF64
    total_elements = a * b * c
    total_bytes = total_elements * element_size
    total_gb = total_bytes / (1024^3) # Convert bytes to GB
    return total_gb
end

function psd_estimate_size(n_epochs, n_channels, n_freqs, n_times)
    element_size = 8 # bytes for Float64
    total_elements = n_epochs * n_channels * n_freqs * n_times
    total_bytes = total_elements * element_size
    total_gb = total_bytes / (1024^3) # Convert bytes to GB
    return total_gb
end

function coh_estimate_size(n_epochs, n_channels, n_freqs)
    element_size = 8 # bytes for Float64
    total_elements = n_epochs * n_channels * n_channels * n_freqs
    total_bytes = total_elements * element_size
    total_gb = total_bytes / (1024^3) # Convert bytes to GB
    return total_gb
end

function scale_dimensions(data, n_taps, freqs, sfreq, n_cycles; print=false, max_gb=0, reserve_gb=0)
    system_mem = Sys.total_memory() / (1024^3)
    if max_gb == 0
        max_gb =  system_mem- reserve_gb # Leave X GB for other stuff
    end
    current_mem = Sys.free_memory() / (1024^3)
    if current_mem <= max_gb
        @warn "Current free memory: $current_mem GB\nDesired max: $max_gb GB\nSystem total memory: $(system_mem) GB\nMemory available is less than desired max!\nAttempting to use available memory!"
        max_gb = current_mem - reserve_gb
    end
    if(max_gb <= 0)
        error("Not enough memory available!")
    end
    
    n_epoch_org, n_channels, n_times = size(data)

    n_freqs = length(freqs)
    t_win = n_cycles / minimum(freqs)
    max_len = Int(ceil(t_win * sfreq))
    nfft = n_times + max_len - 1
    nfft = next_fast_len(nfft)
    
    data_size = Base.summarysize(data) / (1024^3)
    weights = weights_estimate_size(n_taps, n_freqs, n_times)
    Ws = Ws_estimate_size(n_taps, freqs, sfreq, n_cycles)
    fft_Ws = fft_estimate_size(n_taps, n_freqs, nfft)
    fft_X = fft_estimate_size(n_epoch_org, n_channels, nfft)
    coherence_mean = coh_estimate_size(n_epoch_org, n_channels, 1)
    
    current_mem = Sys.free_memory() / (1024^3)
    
    static_total = data_size + weights + Ws + fft_Ws + fft_X + coherence_mean
    
    if static_total >= max_gb || static_total >= current_mem
        println("Static calculations will exceed memory!")
        println("---------------------------------")
        println("Current free memory: $current_mem GB")
        println("---------------------------------")
        println("Data: $(data_size) GB")
        println("Weights: $weights GB")
        println("Ws: $Ws GB")
        println("FFT Ws: $fft_Ws GB")
        println("FFT X: $fft_X GB")
        println("Coherence Mean: $coherence_mean GB")
        println("---------------------------------")
        println("Total: $static_total GB")
        println("Exceeds maximum memory limit of $max_gb GB or current free memory")
        error("Memory limit exceeded")
    end
    n_epochs = copy(n_epoch_org)
    tfr_size = tfr_estimate_size(n_epochs, n_channels, n_taps, n_freqs, n_times)
    psd_per_epoch = psd_estimate_size(n_epochs, n_channels, n_freqs, n_times)
    coherence = coh_estimate_size(n_epoch_org, n_channels, n_freqs)
    coherence_mean_small = 0
    
    dynamic_total = tfr_size + psd_per_epoch + coherence + coherence_mean_small    
    while dynamic_total + static_total >= max_gb && n_epochs > 0
        n_epochs -= 1
        tfr_size = tfr_estimate_size(n_epochs, n_channels, n_taps, n_freqs, n_times)
        psd_per_epoch = psd_estimate_size(n_epochs, n_channels, n_freqs, n_times)
        coherence = coh_estimate_size(n_epochs, n_channels, n_freqs)
        coherence_mean_small = coh_estimate_size(n_epochs, n_channels, 1)
        dynamic_total = tfr_size + psd_per_epoch + coherence + coherence_mean_small
    end

    if n_epochs == 0
        current_mem = Sys.free_memory() / (1024^3)
        println("Can not even compute one epoch with current memory!")
        println("---------------------------------")
        println("Current free memory: $current_mem GB")
        println("System total memory: $system_mem GB")
        println("--------------Static arrays--------------")
        println("Data: $(data_size) GB")
        println("Weights: $weights GB")
        println("Ws: $Ws GB")
        println("FFT Ws: $fft_Ws GB")
        println("FFT X: $fft_X GB")
        println("Coherence Mean: $coherence_mean GB")
        println("Static Total: $static_total GB")
        println("-----Dynamically calculated arrays-----")
        println("TFR: $tfr_size GB")
        println("PSD: $psd_per_epoch GB")
        println("Coherence: $coherence GB")
        println("Coherence Mean (small): $coherence_mean_small GB")
        println("Dynamic Total: $dynamic_total GB")
        println("---------------------------------")
        println("Total: $dynamic_total + $static_total GB")
        println("Exceeds maximum memory limit of $max_gb GB")
        error("Memory limit exceeded")
    end



    if print
        current_mem = Sys.free_memory() / (1024^3)
        println("Can be computed with batches of $n_epochs epochs")
        println("total batches: $(ceil(n_epoch_org/n_epochs))")
        println("--------------Static arrays--------------")
        println("Data: $(data_size) GB")
        println("Weights: $weights GB")
        println("Ws: $Ws GB")
        println("FFT Ws: $fft_Ws GB")
        println("FFT X: $fft_X GB")
        println("Coherence Mean: $coherence_mean GB")
        println("Static Total: $static_total GB")
        println("-----Dynamically calculated arrays-----")
        println("TFR: $tfr_size GB")
        println("PSD: $psd_per_epoch GB")
        println("Coherence: $coherence GB")
        println("Coherence Mean (small): $coherence_mean_small GB")
        println("Dynamic Total: $dynamic_total GB")
        println("---------------------------------")
        println("Total: $(dynamic_total+static_total) GB")
        println("Desired max: $max_gb GB")
        println("---------------------------------")        
        println("Current free memory: $current_mem GB")
        println("System total memory: $system_mem GB")
    end
    return n_epochs
end

function tril_indices(n)::Array{Tuple{Int,Int},1}
    pairs = Array{Tuple{Int,Int},1}(undef, n * (n - 1) ÷ 2)
    q = 1
    for x in 1:n
        for y in (x+1):n
            pairs[q] = (x, y)
            q += 1
        end
    end
    return pairs
end


function next_fast_len(target::Int)::Int
    """
    Find the next fast size of input data to `fft`, for zero-padding, etc.

    Returns the next composite of the prime factors 2, 3, and 5 which is
    greater than or equal to `target`. (These are also known as 5-smooth
    numbers, regular numbers, or Hamming numbers.)

    Parameters
    ----------
    target : Int
        Length to start searching from. Must be a positive integer.

    Returns
    -------
    out : Int
        The first 5-smooth number greater than or equal to `target`.
    """
    # Precomputed Hamming numbers (5-smooth numbers) for quick lookup
    hams = [
        8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, 50,
        54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, 135, 144,
        150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, 256, 270, 288,
        300, 320, 324, 360, 375, 384, 400, 405, 432, 450, 480, 486, 500, 512,
        540, 576, 600, 625, 640, 648, 675, 720, 729, 750, 768, 800, 810, 864,
        900, 960, 972, 1000, 1024, 1080, 1125, 1152, 1200, 1215, 1250, 1280,
        1296, 1350, 1440, 1458, 1500, 1536, 1600, 1620, 1728, 1800, 1875, 1920,
        1944, 2000, 2025, 2048, 2160, 2187, 2250, 2304, 2400, 2430, 2500, 2560,
        2592, 2700, 2880, 2916, 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600,
        3645, 3750, 3840, 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800,
        4860, 5000, 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250,
        6400, 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100,
        8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000
    ]

    if target <= 6
        return target
    end

    # Check if target is already a power of 2
    if (target & (target - 1)) == 0
        return target
    end

    # Quick lookup for small sizes
    if target <= hams[end]
        idx = searchsortedfirst(hams, target)
        return hams[idx]
    end

    # Function to compute the bit length of an integer
    bit_length(x::Int) = x <= 0 ? 0 : floor(Int, log2(x)) + 1

    match = typemax(Int)  # Initialize with maximum possible integer
    p5 = 1
    while p5 < target
        p35 = p5
        while p35 < target
            # Ceiling integer division
            quotient = cld(target, p35)
            p2 = 2^bit_length(quotient - 1)
            N = p2 * p35
            if N == target
                return N
            elseif N < match
                match = N
            end
            p35 *= 3
            if p35 == target
                return p35
            end
        end
        if p35 < match
            match = p35
        end
        p5 *= 5
        if p5 == target
            return p5
        end
    end
    if p5 < match
        match = p5
    end
    return match
end

### Tapers
function _extend(M::Int, sym::Bool)::Tuple{Int,Bool}
    # Extend window by 1 sample if needed for DFT-even symmetry
    if !sym
        return M + 1, true
    else
        return M, false
    end
end


function _fftautocorr(x::AbstractMatrix{<:Float64})::Array{Float64, 2}
    """
    tested vs python:
    isapprox(x_fft, py_x_fft, atol=1e-12) == true
    isapprox(py_cxy, cxy, atol=1e-12) == true
    """
    N = size(x, 2)
    use_N = next_fast_len(2 * N - 1)
    padded = zeros(Float64, size(x, 1), use_N)
    padded[:, 1:N] .= x
    plan = plan_rfft(padded, 2)
    x_fft = plan * padded
    cxy = irfft(x_fft .* conj.(x_fft), use_N, 2)[:, 1:N]
    return cxy
end

function py_dpss(M::Int, NW::Float64, normalization_type::Int, Kmax::Int; sym::Bool=true)::Tuple{Array{Complex{Float64},2},Union{Array{Float64,1},Float64}}
    """
    Compute the Discrete Prolate Spheroidal Sequences (DPSS).

    Parameters
    ----------
    M : Int
        Window length.
    NW : Float64
        Standardized half bandwidth corresponding to 2*NW = BW/f0 = BW*M*dt
        where dt is taken as 1.
    normalization_type : Int
        Normalization of the DPSS windows. Must be one of 1, 2, or 3.
        1: No normalization.
        2: Approximate normalization.
        3: Subsample normalization.
    Kmax : Int
        Number of DPSS windows to return. Must be less than or equal to M and greater than 0.
        If 1, return only a single window of shape (M,)
        instead of an array of windows of shape (Kmax, M).
    sym : Bool, optional
        When true (default), generates a symmetric window, for use in filter design.
        When false, generates a periodic window, for use in spectral analysis.

    return_ratios : Bool, optional
        If true, also return the concentration ratios in addition to the windows.

    Returns
    -------
    windows : Array{Float64, 2} or Array{Float64, 1}
        The DPSS windows. Will be 1D if `Kmax` is nothing.
    ratios : Array{Float64, 1} or Float64, optional
        The concentration ratios for the windows. Only returned if
        `return_ratios` evaluates to true. Will be scalar if `Kmax` is nothing.
    """
    known_norms = (1, 2, 3)
    if normalization_type ∉ known_norms
        error("normalization_type must be one of $known_norms, got $normalization_type")
    end
    if Kmax === 1
        singleton = true
    else
        singleton = false
    end
    if !(0 < Kmax <= M)
        error("Kmax must be greater than 0 and less than or equal to M")
    end
    if NW >= M / 2.0
        error("NW must be less than M/2.")
    end
    if NW <= 0
        error("NW must be positive")
    end

    M, needs_trunc = _extend(M, sym)
    W = NW / M
    nidx = collect(0:M-1)
    d = ((M - 1 .- 2 .* nidx) ./ 2.0) .^ 2 .* cos.(2pi * W)
    e = nidx[2:end] .* (M .- nidx[2:end]) ./ 2.0
    # Use SymTridiagonal for efficient eigenvalue computation
    T = SymTridiagonal(d, e)
    evals = eigvals(T, M-Kmax+1:M);
    evecs = eigvecs(T, evals);
    # Extract the largest Kmax eigenvalues and eigenvectors
    windows = evecs[:, end:-1:1]'
    # Correct sign conventions
    fix_even = sum(windows[1:2:end, :], dims=2) .< 0
    windows[1:2:end, :][fix_even[:, 1], :] .*= -1

    # # Correct signs for even-indexed windows
    thresh = max(1e-7, 1.0 / M)
    for (i, w) in enumerate(eachrow(windows[2:2:end, :]))
        idx = findfirst(x -> x^2 > thresh, w)
        if idx !== nothing && w[idx] < 0
            windows[2i, :] *= -1
        end
    end

    # Compute concentration ratios
    dpss_rxx = _fftautocorr(windows)
    r = 4 * W * sinc.(2 * W .* (nidx))
    r[1] = 2 * W
    ratios = dpss_rxx * r

    if singleton
        ratios = ratios[1]
    end
    # Apply normalization if needed
    if normalization_type != 1
        max_abs = maximum(abs, windows)
        windows ./= max_abs
        if iseven(M)
            if normalization_type == 2
                correction = M^2 / (M^2 + NW)
            elseif normalization_type == 3
                s = rfft(windows[1, :])
                shift = -(1 - 1.0 / M) .* (1:Int(M / 2))
                s[2:end] .*= 2 .* exp.(-im * π .* shift)
                correction = M / sum(real(s))
            end
            windows .*= correction
        end
    end

    if needs_trunc
        windows = windows[:, 1:end-1]
    end
    if singleton
        windows = windows[1, :]
    end
    return windows, ratios
end

function compute_tapers(N::Int, n_taps::Int, freqs::AbstractArray{<:Real}, mt_bandwidth::Real, n_cycles::Int, sfreq::Int; zero_mean::Bool=true)::Tuple{Matrix{Vector{ComplexF64}},Array{Float64,3}}
    n_freqs = length(freqs)
    weights = Array{Float64,3}(undef, n_taps, n_freqs, N)
    Ws = Matrix{Vector{ComplexF64}}(undef, n_taps, n_freqs)
    sp5 = sqrt(0.5)
    # Loop over frequencies first
    Threads.@threads for k in eachindex(freqs)
        f = freqs[k]
        t_win = n_cycles / f
        len_t = Int(ceil(t_win * sfreq))

        t = collect(0:1/sfreq:t_win-(t_win % (1 / sfreq) == 0 ? 1 / sfreq : 0)) # exclude last value if it fits exactly
        t_centered = t .- t_win / 2.0

        # Precompute oscillation and taper
        oscillation = exp.(2.0 * im * pi * f .* t_centered)

        taper, e = py_dpss(len_t, mt_bandwidth / 2, 1, n_taps, sym=false)
        weights[:, k, :] .= sqrt.(e)

        for m = 1:n_taps
            # Use @view to avoid copying taper column
            Wk = oscillation .* @view taper[m, :]

            if zero_mean  # To make it zero mean
                real_offset = mean(Wk)
                Wk .-= real_offset
            end

            # Normalize Wk
            Wk /= sp5 * norm(Wk)

            # Store Wk in preallocated Ws
            Ws[m, k] = Wk
        end
    end
    return Ws, weights
end
### end tapers


function _get_nfft(Ws::Matrix{Vector{ComplexF64}}, X::AbstractArray{<:Float64})::Int
    max_len = maximum([length(Wk) for Wk in Ws])
    n = last(size(X))
    nfft = n + max_len - 1
    # @show nfft
    nfft = next_fast_len(nfft)
    return nfft
end

function coh(s_xx::AbstractMatrix{Float64}, s_yy::AbstractMatrix{Float64}, s_xy::AbstractMatrix{ComplexF64})::Array{Float64}
    # Compute the numerator: absolute value of the mean of s_xy along the last dimension
    con_num = abs.(mean(s_xy, dims=ndims(s_xy)))

    # Compute the denominator: square root of the product of means of s_xx and s_yy along the last dimension
    con_den = sqrt.(mean(s_xx, dims=ndims(s_xx)) .* mean(s_yy, dims=ndims(s_yy)))

    # Calculate coherence as the element-wise division of numerator by denominator
    coh = con_num ./ con_den
    return coh
end
function coh(s_xx::Array{Float64, 2}, s_yy::Array{Float64, 2}, s_xy::Array{ComplexF64, 2})::Array{Float64}
    # Compute the numerator: absolute value of the mean of s_xy along the last dimension
    con_num = abs.(mean(s_xy, dims=ndims(s_xy)))

    # Compute the denominator: square root of the product of means of s_xx and s_yy along the last dimension
    con_den = sqrt.(mean(s_xx, dims=ndims(s_xx)) .* mean(s_yy, dims=ndims(s_yy)))

    # Calculate coherence as the element-wise division of numerator by denominator
    coh = con_num ./ con_den
    return coh
end


# Precompute FFTs of Ws
function precompute_fft_Ws(Ws, nfft)
    n_taps, n_freqs = size(Ws)
    fft_Ws = zeros(ComplexF64, n_taps, n_freqs, nfft) # preallocated padded array
    
    # tried threading but it was slightly slower
    for taper_idx = 1:n_taps
        for freq_idx = 1:n_freqs
            # Ws are different lengths
            fft_Ws[taper_idx, freq_idx, 1:length(Ws[taper_idx, freq_idx])] .= Ws[taper_idx, freq_idx]
        end
    end
    # plan fft!
    p = plan_fft!(fft_Ws, 3)
    return p * fft_Ws
end

# Precompute FFTs of X
function precompute_fft_X(X, nfft)
    n_epochs, n_channels, n_times = size(X)
    fft_X = zeros(ComplexF64,n_epochs, n_channels, nfft)
    fft_X[:,:,1:n_times] .= data
    p = plan_fft!(fft_X, 3)
    return p * fft_X
end


function compute_tfr!(tfr::Array{ComplexF64, 5}, fft_X::Array{ComplexF64, 3}, fft_Ws::Array{ComplexF64, 3}, Ws_lengths::Array{Int64, 2})
    batch_size, n_channels, nfft = size(fft_X)
    n_taps, n_freqs, _ = size(fft_Ws)
    _, _, _, _, n_times = size(tfr)
    
    # Precompute sizes, start_indices, and end_indices
    sizes = n_times .+ Ws_lengths .- 1
    start_indices = floor.(Int, (sizes .- n_times) ./ 2) .+ 1
    end_indices = start_indices .+ n_times .- 1
    
    nthreads = Threads.nthreads()
    temp_arrays = [Array{ComplexF64}(undef, batch_size, n_channels, nfft) for _ in 1:nthreads]
    fft_plans = [plan_ifft!(temp_arrays[i], 3) for i in 1:nthreads]
    
    # Thread over frequencies
    @inbounds @showprogress desc="Computing TFRs" Threads.@threads for freq_idx = 1:n_freqs
        thread_id = Threads.threadid()
        temp = temp_arrays[thread_id]
        ifft_plan = fft_plans[thread_id]
        
        # Loop over tapers
        for taper_idx = 1:n_taps
            fft_W = fft_Ws[taper_idx, freq_idx, :]  # Current fft_W
            Ws_length = Ws_lengths[taper_idx, freq_idx]
            ret_size = n_times + Ws_length - 1
            
            # Compute start and end indices for slicing
            start = start_indices[taper_idx, freq_idx]
            end_time = end_indices[taper_idx, freq_idx]
            
            # Compute the product and inverse FFT in-place
            temp .= fft_X .* reshape(fft_W, 1, 1, nfft)  # Broadcasting over first two dims
            temp .= ifft_plan * temp  # In-place inverse FFT
            
            # Assign the centered result to tfr
            tfr[:, :, taper_idx, freq_idx, :] .= temp[:, :, start:end_time]
        end
    end
    
    return tfr
end

# function compute_tfr!(tfr::Array{ComplexF64, 5}, fft_X::Array{ComplexF64, 3}, fft_Ws::Array{ComplexF64, 3}, Ws_lengths::Array{Int64, 2})::Array{ComplexF64,5}
#     batch_size, n_channels, n_taps, n_freqs, n_times = size(tfr)

#     @showprogress desc="Computing TFRs" Threads.@threads for idx in 1:(n_taps*n_freqs*batch_size*n_channels)
#     # @showprogress desc="Computing TFRs" for idx in 1:(n_taps*n_freqs*batch_size*n_channels)
#         # Compute indices from idx
#         # @show idx
#         taper_idx = ((idx - 1) ÷ (n_freqs * batch_size * n_channels)) + 1
#         rem1 = (idx - 1) % (n_freqs * batch_size * n_channels)
#         freq_idx = (rem1 ÷ (batch_size * n_channels)) + 1
#         rem2 = rem1 % (batch_size * n_channels)
#         epoch_idx = (rem2 ÷ n_channels) + 1
#         channel_idx = (rem2 % n_channels) + 1

#         fft_W = @view fft_Ws[taper_idx, freq_idx, :]
#         W_size = Ws_lengths[taper_idx, freq_idx]
#         total_size = n_times + W_size - 1
#         ret_size = total_size

#         fx = @view fft_X[epoch_idx, channel_idx, :]
#         product = fx .* fft_W
#         ret = ifft(product)[1:ret_size]

#         # # # Center the result
#         start = Int(floor((ret_size - n_times) / 2)) + 1
#         end_time = start + n_times - 1
#         tfr[epoch_idx, channel_idx, taper_idx, freq_idx, :] .= ret[start:end_time]

#     end

#     return tfr
# end

# function compute_psd(batch_size::Int, n_channels::Int, n_freqs::Int, n_times::Int, tfrs::Array{ComplexF64,5}, weights::Array{Float64,3}, normalization::Array{Float64, 3})::Array{Float64,4}
#     psd_per_epoch = Array{Float64,4}(undef, batch_size, n_channels, n_freqs, n_times)
#     @showprogress desc = "Computing epoch's PSD..." Threads.@threads for idx = 1:(batch_size*n_channels)
#         # Compute epoch_idx and c_idx from idx
#         epoch_idx = div(idx - 1, n_channels) + 1
#         c_idx = mod(idx - 1, n_channels) + 1

#         # Perform the element-wise multiplication with broadcasting
#         psd = weights .* @view tfrs[epoch_idx, c_idx, :, :, :]

#         # Square magnitude (complex conjugate multiplication)
#         psd .= psd .* conj(psd)

#         # Sum across the first dimension (tapers)
#         psd = sum(real(psd), dims=1)

#         # Apply the normalization
#         psd .= psd .* normalization

#         # Update the psd_per_epoch array
#         psd_per_epoch[epoch_idx, c_idx, :, :] .= psd[1, :, :]
#     end
#     return psd_per_epoch
# end

function compute_psd!(psd_per_epoch::Array{Float64,4}, tfrs::Array{ComplexF64,5}, weights::Array{Float64,3}, normalization::Array{Float64, 2})::Array{Float64,4}
    batch_size, n_channels, n_tapers, n_freqs, n_times = size(tfrs)

    nthreads = Threads.nthreads()
    psd_arrays = [Array{ComplexF64}(undef, n_tapers, n_freqs, n_times) for _ in 1:nthreads]
    psd_sums = [Array{Float64}(undef, n_freqs, n_times) for _ in 1:nthreads]

    @inbounds @showprogress desc="Computing epoch's PSD..." for idx = 1:(batch_size * n_channels)
        thread_id = Threads.threadid()
        psd = psd_arrays[thread_id]
        psd_sum = psd_sums[thread_id]

        # Compute epoch_idx and channel_idx from idx
        epoch_idx = div(idx - 1, n_channels) + 1
        c_idx = mod(idx - 1, n_channels) + 1

        # Extract the current tfr slice
        tfr_view = @view tfrs[epoch_idx, c_idx, :, :, :]

        # Perform the element-wise multiplication
        @. psd = weights * tfr_view

        # Compute the squared magnitude
        @. psd = psd * conj(psd)

        # Sum across the first dimension (tapers)
        psd_sum .= 0.0
        @inbounds for t = 1:n_tapers
            @views psd_sum .= psd_sum .+ real(psd[t, :, :])
        end

        # Apply the normalization
        @. psd_sum = psd_sum * normalization

        # Update the psd_per_epoch array
        psd_per_epoch[epoch_idx, c_idx, :, :] .= psd_sum
    end
    return psd_per_epoch
end

function compute_coh_mean(epochs::Int, n_channels::Int, n_freqs::Int, n_pairs::Int, pairs::Array{Tuple{Int, Int},1},tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights::Array{Float64,3}, normalization::Array{Float64, 3})::Array{Float64,4}
    coherence = Array{Float64,4}(undef, epochs, n_channels, n_channels, n_freqs)
    @showprogress desc = "Computing Coherence..." Threads.@threads for idx in 1:epochs*n_pairs
        # Calculate the epoch index and pair index
        epoch_idx = div(idx - 1, n_pairs) + 1
        pair_idx = mod(idx - 1, n_pairs) + 1
        x, y = pairs[pair_idx]
        # println("Epoch: $epoch_idx, Pair: ($x, $y)")
        # Now perform your operations
        w_x = @view tfrs[epoch_idx, x, :, :, :]
        w_y = @view tfrs[epoch_idx, y, :, :, :]
        s_xy = sum(weights .* w_x .* conj(weights .* w_y), dims=1)  # sum over tapers
        s_xy = s_xy .* normalization

        s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
        s_yy = @view psd_per_epoch[epoch_idx, y, :, :]
        coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

        coherence[epoch_idx, y, x, :] .= coh_value
    end
    return mean(coherence, dims=ndims(coherence))
end

# t = 32
# sr = 32
# f = 2

# v = sinusoidal(10, f, sr, t * 4, 0)
# w = sinusoidal(10, f, sr, t * 4, π / 4)
# y = sinusoidal(10, f, sr, t * 4, π / 2)
# z = sinusoidal(10, f, sr, t * 4, π)

# data = Array{Float64}(undef, 2, 4, 128)

# data[1, :, :] = hcat(v, w, y, z)'
# data[2, :, :] = hcat(-v, -w, -y, -z)';

# freqs = collect(2:15) # inclusive of end 
# n_freqs = length(freqs)
# mt_bandwidth = 4
# n_taps = floor(Int, mt_bandwidth - 1)
# n_cycles = 7
# sfreq = 32
# zero_mean = true


outputpath = "/media/dan/Data/git/network_mining/connectivity/julia_test/"
data = npzread("/media/dan/Data/git/network_mining/connectivity/julia_test/034_input.npy")
data = data[1:2, 1:10, :]

sfreq = 2048
freqs = collect(14:50)
zero_mean = true
n_freqs = length(freqs)
mt_bandwidth = 4
n_taps = floor(Int, mt_bandwidth - 1)
n_cycles = 7
n_epochs, n_channels, n_times = size(data)

batch_size = scale_dimensions(data, n_taps, freqs, sfreq, n_cycles, print=true, reserve_gb=60)
total_batches = ceil(n_epochs / batch_size)
if batch_size != n_epochs
    println("Data is too big for one pass!\nData will be computed in batches of $batch_size epochs. Total batches: $(total_batches)")
end


println("Making tapers...")
Ws, weights = compute_tapers(n_times, n_taps, freqs, mt_bandwidth, n_cycles, sfreq)
weights_squared = weights .^ 2
normalization = 2 ./ sum(real(weights .* conj(weights)), dims=1);
small_norm = dropdims(normalization; dims=1)

nfft = _get_nfft(Ws, data)

println("Precomputing FFTs of tapers and data...")
fft_Ws = precompute_fft_Ws(Ws, nfft);
fft_X = precompute_fft_X(data, nfft);
println("Done!")

# save(joinpath(outputpath, "034_pretasks.jld2"), "Ws", Ws, "weights", weights, "fft_Ws", fft_Ws, "fft_X", fft_X, "normalization", normalization)

Ws_lengths = [length(Wk) for Wk in Ws]

println("Preparing for computation...")
pairs = tril_indices(n_channels)
n_pairs = length(pairs)


tfr = Array{ComplexF64,5}(undef, batch_size, n_channels, n_taps, n_freqs, n_times);
compute_tfr!(tfr, fft_X, fft_Ws, Ws_lengths);

psd_per_epoch = Array{Float64,4}(undef, batch_size, n_channels, n_freqs, n_times);
compute_psd!(psd_per_epoch, tfr, weights, small_norm);


Can be computed with batches of 2 epochs
total batches: 1.0
--------------Static arrays--------------
Data: 0.00015264004468917847 GB
Weights: 0.00084686279296875 GB
Ws: 0.0008460581302642822 GB
FFT Ws: 0.003387451171875 GB
FFT X: 0.0006103515625 GB
Coherence Mean: 1.4901161193847656e-6 GB
Static Total: 0.0058448538184165955 GB
-----Dynamically calculated arrays-----
TFR: 0.03387451171875 GB
PSD: 0.005645751953125 GB
Coherence: 5.513429641723633e-5 GB
Coherence Mean (small): 0 GB
Dynamic Total: 0.039575397968292236 GB
---------------------------------
Total: 0.04542025178670883 GB
Desired max: 191.5341453552246 GB
---------------------------------
Current free memory: 237.50051498413086 GB
System total memory: 251.5341453552246 GB
Making tapers...
Precomputing FFTs of tapers and data...
Done!
Preparing for computation...


[32mComputing TFRs 100%|█████████████████████████████████████| Time: 0:00:00[39m


In [24]:
function compute_coh_mean1(epochs::Int, n_channels::Int, n_freqs::Int, n_pairs::Int, pairs::Array{Tuple{Int, Int},1}, tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights::Array{Float64,3}, normalization::Array{Float64, 2})::Array{Float64,4}

# Initialize coherence array
# coherence = zeros(Float64, epochs, n_channels, n_channels, n_freqs)
coherence = Array{Float64,4}(undef, epochs, n_channels, n_channels, n_freqs)

# Precompute weights squared
weights_squared = weights .^ 2  # size: (n_tapers, n_freqs, n_times)

n_tapers, n_freqs, n_times = size(weights)
nthreads = Threads.nthreads()

# Preallocate per-thread temporary arrays
s_xy_arrays = [zeros(ComplexF64, n_freqs, n_times) for _ in 1:nthreads]
coh_values = [zeros(Float64, n_freqs) for _ in 1:nthreads]

# s_xy_arrays = [Array{ComplexF64,2}(undef, n_freqs, n_times) for _ in 1:nthreads]
# coh_values = [Array{ComplexF64,1}(undef, n_freqs) for _ in 1:nthreads]

Threads.@threads for idx in 1:(epochs * n_pairs)
    thread_id = Threads.threadid()
    s_xy = s_xy_arrays[thread_id]
    coh_value = coh_values[thread_id]

    # Calculate the epoch index and pair index
    epoch_idx = div(idx - 1, n_pairs) + 1
    pair_idx = mod(idx - 1, n_pairs) + 1
    x, y = pairs[pair_idx]

    # Extract slices for the current epoch and channel pair
    w_x = @view tfrs[epoch_idx, x, :, :, :]
    w_y = @view tfrs[epoch_idx, y, :, :, :]

    # Reset s_xy to zero
    s_xy .= 0.0

    # Compute s_xy using in-place operations and avoid allocations
    @inbounds for t = 1:n_tapers
        @. s_xy .+= @view(weights_squared[t, :, :]) .* w_x[t, :, :] .* conj(w_y[t, :, :])
    end

    # Apply normalization
    s_xy .= s_xy .* normalization

    # Extract s_xx and s_yy
    s_xx = @view psd_per_epoch[epoch_idx, x, :, :]  # size: (n_freqs, n_times)
    s_yy = @view psd_per_epoch[epoch_idx, y, :, :]

    # Compute coherence values over frequencies
    @inbounds for f = 1:n_freqs
        s_xy_mean = mean(s_xy[f, :])
        s_xx_mean = mean(s_xx[f, :])
        s_yy_mean = mean(s_yy[f, :])
        coh_value[f] = abs(s_xy_mean) / sqrt(s_xx_mean * s_yy_mean)
    end

    # Update the coherence array
    coherence[epoch_idx, y, x, :] .= coh_value
end

# Compute the mean over freq
coherence_mean = mean(coherence, dims=ndims(coherence))
# coherence_mean = dropdims(coherence_mean; dims=1)  # Remove the singleton dimension

return coherence_mean  # size: (n_channels, n_channels, n_freqs)
end


compute_coh_mean1 (generic function with 1 method)

In [None]:


function compute_coh_mean3(epochs::Int, n_channels::Int, n_freqs::Int, n_pairs::Int, tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights_squared::Array{Float64,3}, normalization::Array{Float64, 3})::Array{Float64,4}
    coherence = Array{Float64,4}(undef, epochs, n_channels, n_channels, n_freqs)
    @showprogress desc = "Computing Coherence..." Threads.@threads for idx in 1:epochs*n_pairs
        # Calculate the epoch index and pair index
        epoch_idx = div(idx - 1, n_pairs) + 1
        pair_idx = mod(idx - 1, n_pairs) + 1
        x, y = pairs[pair_idx]
        # println("Epoch: $epoch_idx, Pair: ($x, $y)")
        # Now perform your operations
        w_x = @view tfrs[epoch_idx, x, :, :, :]
        w_y = @view tfrs[epoch_idx, y, :, :, :]
        s_xy =  sum(weights_squared .* w_x .* conj(w_y), dims=1)  # sum over tapers
        s_xy .*= normalization

        s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
        s_yy = @view psd_per_epoch[epoch_idx, y, :, :]
        coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

        # Copy to symmetric position
        coherence[epoch_idx, y, x, :] .= coh_value
    end
    return mean(coherence, dims=ndims(coherence))
end
compute_coh_mean3(batch_size, n_channels, n_freqs, n_pairs, tfr, psd_per_epoch, weights_squared, normalization)



function compute_coh_mean4(epochs::Int, n_channels::Int, n_freqs::Int, n_pairs::Int, tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights_squared::Array{Float64,3}, normalization::Array{Float64, 3})::Array{Float64,4}
    coherence = Array{Float64,4}(undef, epochs, n_channels, n_channels, n_freqs)
    @showprogress desc = "Computing Coherence..." Threads.@threads for idx in 1:epochs*n_pairs
        # Calculate the epoch index and pair index
        epoch_idx = div(idx - 1, n_pairs) + 1
        pair_idx = mod(idx - 1, n_pairs) + 1
        x, y = pairs[pair_idx]
        # println("Epoch: $epoch_idx, Pair: ($x, $y)")
        # Now perform your operations
        w_x = @view tfrs[epoch_idx, x, :, :, :]
        w_y = @view tfrs[epoch_idx, y, :, :, :]
        s_xy =  sum(weights_squared .* w_x .* conj(w_y), dims=1)  # sum over tapers
        s_xy .*= normalization

        s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
        s_yy = @view psd_per_epoch[epoch_idx, y, :, :]
        coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

        # Copy to symmetric position
        coherence[epoch_idx, y, x, :] .= coh_value
    end
    return mean(coherence, dims=ndims(coherence))
end
compute_coh_mean4(batch_size, n_channels, n_freqs, n_pairs, tfr, psd_per_epoch, weights_squared, normalization)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m


2×10×10×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.19058e7  0.571331  0.533617  0.479512  …  0.311142  0.305408  0.288305
 1.37196e7  0.657312  0.538955  0.43261      0.325697  0.307234  0.297361

[:, :, 2, 1] =
   3.81004e8  5.2985e8     0.906335  …  0.475408  0.500495  0.494158
 NaN          2.14286e-11  0.874699     0.206805  0.203231  0.24465

[:, :, 3, 1] =
 1.93103e5  1.69563e10    2.43856e10  …  0.540325  0.560476  0.561652
 2.93167e9  2.0671e10   NaN              0.311789  0.203429  0.195551

[:, :, 4, 1] =
   3.90184e11  5.42637e11  7.80375e11  …  0.501037  0.535014  0.538405
 NaN           5.80582e10  8.99243e11     0.416335  0.219975  0.161871

[:, :, 5, 1] =
 1.24865e13  0.0154135     2.49732e13  …  0.472128  0.454038  0.433354
 6.91224e11  8.64003e10  NaN              0.565884  0.450237  0.386541

[:, :, 6, 1] =
   3.99585e14  5.55737e14   3.5394e14  …  0.458761  0.476843  0.47256
 NaN           3.66121e-19  0.817593      0.550091  0.439358  0.352998

[:, :, 7, 1] =
 2.7096e15

In [49]:
function compute_coh_mean5(epochs::Int, n_channels::Int, n_freqs::Int, n_pairs::Int, tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights_squared::Array{Float64,3}, normalization::Array{Float64, 2})::Array{Float64,4}
    coherence = Array{Float64,4}(undef, epochs, n_channels, n_channels, n_freqs)
    @showprogress desc = "Computing Coherence..." Threads.@threads for idx in 1:epochs*n_pairs
        # Calculate the epoch index and pair index
        epoch_idx = div(idx - 1, n_pairs) + 1
        pair_idx = mod(idx - 1, n_pairs) + 1
        x, y = pairs[pair_idx]
        # println("Epoch: $epoch_idx, Pair: ($x, $y)")
        # Now perform your operations
        w_x = @view tfrs[epoch_idx, x, :, :, :]
        w_y = @view tfrs[epoch_idx, y, :, :, :]
        s_xy =  dropdims(sum(weights_squared .* w_x .* conj.(w_y), dims=1),dims=1)  # sum over tapers
        s_xy .*= normalization


        s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
        s_yy = @view psd_per_epoch[epoch_idx, y, :, :]

        # Compute the numerator: absolute value of the mean of s_xy along the last dimension
        con_num = abs.(mean(s_xy, dims=2))

        # Compute the denominator: square root of the product of means of s_xx and s_yy along the last dimension
        con_den = sqrt.(mean(s_xx, dims=2) .* mean(s_yy, dims=2))

        # Calculate coherence as the element-wise division of numerator by denominator
        coh_value = con_num ./ con_den
        # coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

        # Copy to symmetric position
        coherence[epoch_idx, y, x, :] .= coh_value
    end
    return mean(coherence, dims=4)
end
compute_coh_mean5(batch_size, n_channels, n_freqs, n_pairs, tfr, psd_per_epoch, weights_squared, small_norm)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m


2×10×10×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.16675e241  0.571331  0.533617  0.479512  …  0.311142  0.305408  0.288305
 2.58039e178  0.657312  0.538955  0.43261      0.325697  0.307234  0.297361

[:, :, 2, 1] =
 NaN  1.74686e169  0.906335  0.764978  …  0.475408  0.500495  0.494158
 NaN  1.25887e150  0.874699  0.603005     0.206805  0.203231  0.24465

[:, :, 3, 1] =
 1.29493e179  1.87596e227    1.09185e305  …  0.540325  0.560476  0.561652
 2.76009e275  1.15351e241  NaN               0.311789  0.203429  0.195551

[:, :, 4, 1] =
   1.85316e286  1.32528e251  -5.98394e259  …  0.501037  0.535014  0.538405
 NaN            3.06127e275   1.0256e297      0.416335  0.219975  0.161871

[:, :, 5, 1] =
 1.15201e179  3.06127e275     1.7602e174  …  0.472128  0.454038  0.433354
 1.32818e251  1.15062e-141  NaN              0.565884  0.450237  0.386541

[:, :, 6, 1] =
   1.02114e179  7.90001e46   1.28096e169  …  0.458761  0.476843  0.47256
 NaN            1.28031e169  6.85809e-53     0.550091  0.439358 

In [13]:
@benchmark compute_coh_mean(batch_size, n_channels, n_freqs, n_pairs, pairs, tfr, psd_per_epoch, weights, normalization)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32

BenchmarkTools.Trial: 13 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m368.201 ms[22m[39m … [35m422.713 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m7.76% … 2.01%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m389.048 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m8.48%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m388.709 ms[22m[39m ± [32m 14.029 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.87% ± 1.80%

  [39m▁[39m [39m [39m [39m [39m▁[39m [39m [39m [39m [39m [39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [34m▁[39m[39m [32m█[39m[39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m 
  [39m█[39m▁[39m▁[39m▁

In [25]:
@benchmark compute_coh_mean1(batch_size, n_channels, n_freqs, n_pairs, pairs, tfr, psd_per_epoch, weights, small_norm)

BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m352.709 ms[22m[39m … [35m395.128 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m3.27% … 0.89%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m360.145 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m6.73%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m362.191 ms[22m[39m ± [32m  9.726 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m6.12% ± 1.79%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[34m▂[39m[39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▅[39m▁[39m▁[39m▁

In [16]:
@benchmark compute_coh_mean3(batch_size, n_channels, n_freqs, n_pairs, tfr, psd_per_epoch, weights_squared, normalization)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32

BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m346.278 ms[22m[39m … [35m404.716 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m8.50% … 9.63%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m378.857 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m8.70%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m378.977 ms[22m[39m ± [32m 13.224 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m8.36% ± 2.20%

  [39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m▁[39m [39m▁[39m [39m [39m▁[34m▁[39m[32m▁[39m[39m [39m [39m [39m [39m [39m▁[39m▁[39m█[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m 
  [39m█[39m▁[39m▁[39m▁

In [17]:
@benchmark compute_coh_mean4(batch_size, n_channels, n_freqs, n_pairs, tfr, psd_per_epoch, weights_squared, normalization)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32

BenchmarkTools.Trial: 18 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m271.669 ms[22m[39m … [35m323.999 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m5.68% … 0.68%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m281.630 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m7.66%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m283.773 ms[22m[39m ± [32m 12.035 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.15% ± 1.72%

  [39m▁[39m█[39m▁[39m▁[39m [39m [39m [39m█[39m▁[39m [39m▁[34m [39m[39m▁[39m [32m▁[39m[39m [39m [39m▁[39m▁[39m [39m█[39m▁[39m [39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m 
  [39m█[39m█[39m█[39m█

In [50]:
@benchmark compute_coh_mean5(batch_size, n_channels, n_freqs, n_pairs, tfr, psd_per_epoch, weights_squared, small_norm)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32

BenchmarkTools.Trial: 25 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m193.785 ms[22m[39m … [35m233.333 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m3.69% … 0.49%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m197.345 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m5.57%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m200.032 ms[22m[39m ± [32m  7.997 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m5.23% ± 1.53%

  [39m [39m [39m [39m [39m█[34m▆[39m[39m▆[39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m▁[39m▄[39m▁

In [None]:
coherence = Array{Float64,4}(undef, batch_size, n_channels, n_channels, n_freqs)

function compute_coh6!(coherence::Array{Float64,4}, tfrs::Array{ComplexF64,5}, pairs::Vector{Tuple{Int64, Int64}}, psd_per_epoch::Array{Float64,4}, weights_squared::Array{Float64,3}, normalization::Array{Float64, 2})::Array{Float64,4}
    batch_size, n_channels, n_taps, n_freqs, n_times = size(tfr)
    n_pairs = length(pairs)

    nthreads = Threads.nthreads()
    temp_arrays = [Array{ComplexF64,3}(undef, n_taps, n_freqs, n_times) for _ in 1:nthreads]
    @showprogress desc = "Computing Coherence..." Threads.@threads for idx in 1:batch_size*n_pairs
        thread_id = Threads.threadid()
        temp = temp_arrays[thread_id]

        # Calculate the epoch index and pair index
        epoch_idx = div(idx - 1, n_pairs) + 1
        pair_idx = mod(idx - 1, n_pairs) + 1
        x, y = pairs[pair_idx]
        # println("Epoch: $epoch_idx, Pair: ($x, $y)")
        # Now perform your operations
        w_x = @view tfrs[epoch_idx, x, :, :, :]
        w_y = @view tfrs[epoch_idx, y, :, :, :]
        temp .= weights_squared .* w_x .* conj.(w_y)
        s_xy = dropdims(sum(temp, dims=1),dims=1)  # sum over tapers
        s_xy .*= normalization

        s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
        s_yy = @view psd_per_epoch[epoch_idx, y, :, :]

        # Compute the numerator: absolute value of the mean of s_xy along the last dimension
        con_num = abs.(mean(s_xy, dims=2))

        # Compute the denominator: square root of the product of means of s_xx and s_yy along the last dimension
        con_den = sqrt.(mean(s_xx, dims=2) .* mean(s_yy, dims=2))

        # Calculate coherence as the element-wise division of numerator by denominator
        coh_value = con_num ./ con_den
        # coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

        # Copy to symmetric position
        coherence[epoch_idx, y, x, :] .= coh_value
    end
    return coherence
end
compute_coh6!(coherence, tfr, pairs, psd_per_epoch, weights_squared, small_norm)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m


2×10×10×37 Array{Float64, 4}:
[:, :, 1, 1] =
 1.01227e-315  0.736191  0.416704  0.51118   …  0.440438  0.209318  0.241022
 1.05469e-315  0.926867  0.837006  0.728169     0.359982  0.294933  0.180178

[:, :, 2, 1] =
 0.0           6.36599e-314  0.686981  …  0.321598  0.335393   0.371932
 9.25881e-316  0.0           0.920761     0.239729  0.0920393  0.0825587

[:, :, 3, 1] =
 6.36599e-314  9.25882e-316  0.0           …  0.799185  0.720752  0.694042
 0.0           6.36599e-314  9.25882e-316     0.447055  0.167288  0.142126

[:, :, 4, 1] =
 9.25882e-316  0.0           6.36599e-314  …  0.774481  0.679359  0.618833
 6.36599e-314  9.25882e-316  0.0              0.490607  0.22802   0.234667

[:, :, 5, 1] =
 0.0           6.36599e-314  9.25883e-316  …  0.544387  0.482208  0.430485
 9.25883e-316  0.0           6.36599e-314     0.606839  0.362484  0.275902

[:, :, 6, 1] =
 6.36599e-314  9.25884e-316  0.0           …  0.573803  0.363882  0.26413
 0.0           6.36599e-314  9.25884e-316     0.6157

In [76]:
@benchmark compute_coh6!(coherence, tfr, pairs, psd_per_epoch, weights_squared, small_norm)

[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32mComputing Coherence... 100%|█████████████████████████████| Time: 0:00:00[39m
[32

BenchmarkTools.Trial: 26 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m189.896 ms[22m[39m … [35m217.084 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m195.821 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m1.79%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m199.200 ms[22m[39m ± [32m  8.012 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m1.60% ± 0.84%

  [39m [39m [39m [39m [39m [39m [39m [39m▃[39m█[39m▃[39m▃[39m▃[34m [39m[39m [39m▃[39m▃[39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▇[39m▁[39m▁[39m▁

In [55]:
function test(pairs::Array{Tuple{Int, Int},1}, tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights_squared::Array{Float64,3}, normalization::Array{Float64, 2})::Matrix{Float64}
    epoch_idx = 1
    pair_idx = 1

    x, y = pairs[pair_idx]
    # println("Epoch: $epoch_idx, Pair: ($x, $y)")
    # Now perform your operations
    w_x = @view tfrs[epoch_idx, x, :, :, :]
    w_y = @view tfrs[epoch_idx, y, :, :, :]
    s_xy =  dropdims(sum(weights_squared .* w_x .* conj.(w_y), dims=1),dims=1)  # sum over tapers
    s_xy .*= normalization


    s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
    s_yy = @view psd_per_epoch[epoch_idx, y, :, :]

    # Compute the numerator: absolute value of the mean of s_xy along the last dimension
    con_num = abs.(mean(s_xy, dims=2))

    # Compute the denominator: square root of the product of means of s_xx and s_yy along the last dimension
    con_den = sqrt.(mean(s_xx, dims=2) .* mean(s_yy, dims=2))

    # Calculate coherence as the element-wise division of numerator by denominator
    coh_value = con_num ./ con_den
    # coh_value = coh(s_xx, s_yy, s_xy[1, :, :])


    return coh_value
end

test(pairs, tfr, psd_per_epoch, weights_squared, small_norm);
# coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

# Copy to symmetric position
# coherence[epoch_idx, y, x, :] .= coh_value

In [54]:
@benchmark test(pairs, tfr, psd_per_epoch, weights_squared, small_norm)

BenchmarkTools.Trial: 2752 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m1.573 ms[22m[39m … [35m  7.603 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 76.85%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m1.642 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m1.807 ms[22m[39m ± [32m454.427 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m6.36% ± 11.41%

  [39m [39m▆[39m█[34m▆[39m[39m▄[39m▁[39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m▃[39m▁[39m [39m [39m▁
  [39m▃[39m█[39m█[34m█[39m[39m█[3

In [None]:
function test(pairs::Array{Tuple{Int, Int},1}, tfrs::Array{ComplexF64,5}, psd_per_epoch::Array{Float64,4}, weights_squared::Array{Float64,3}, normalization::Array{Float64, 2})::Matrix{Float64}
    epoch_idx = 1
    pair_idx = 1

    x, y = pairs[pair_idx]
    # println("Epoch: $epoch_idx, Pair: ($x, $y)")
    # Now perform your operations
    w_x = @view tfrs[epoch_idx, x, :, :, :]
    w_y = @view tfrs[epoch_idx, y, :, :, :]
    s_xy =  dropdims(sum(weights_squared .* w_x .* conj.(w_y), dims=1),dims=1)  # sum over tapers
    s_xy .*= normalization


    s_xx = @view psd_per_epoch[epoch_idx, x, :, :]
    s_yy = @view psd_per_epoch[epoch_idx, y, :, :]

    # Compute the numerator: absolute value of the mean of s_xy along the last dimension
    con_num = abs.(mean(s_xy, dims=2))

    # Compute the denominator: square root of the product of means of s_xx and s_yy along the last dimension
    con_den = sqrt.(mean(s_xx, dims=2) .* mean(s_yy, dims=2))

    # Calculate coherence as the element-wise division of numerator by denominator
    coh_value = con_num ./ con_den
    # coh_value = coh(s_xx, s_yy, s_xy[1, :, :])


    return coh_value
end

test(pairs, tfr, psd_per_epoch, weights_squared, small_norm);
# coh_value = coh(s_xx, s_yy, s_xy[1, :, :])

# Copy to symmetric position
# coherence[epoch_idx, y, x, :] .= coh_value