In [None]:
using LinearAlgebra
using FFTW
using Statistics
using JLD2
using NPZ
using ProgressMeter
# using BenchmarkTools

"""
    sinusoidal(a, f, sr, t, theta=0, DC=0)

Return `t` samples of a sine wave with amplitude `a`, frequency `f`, phase
offset `theta` and constant offset `DC`, sampled at rate `sr`.

Sample `k` (for `k = 0, ..., t-1`) is `DC + a * sin(2π * f * k / sr + theta)`.
"""
function sinusoidal(a, f, sr, t, theta=0, DC=0)
    dt = 1 / sr          # spacing between consecutive samples
    omega = f * 2 * π    # angular frequency
    return map(0:(t-1)) do k
        DC + (a * sin(omega * k * dt + theta))
    end
end

"""
    tfr_estimate_size(n_epochs, n_channels, n_taps, n_freqs, n_times)

Estimate the size, in GB (1024^3 bytes), of a `ComplexF64` array with
dimensions `n_epochs × n_channels × n_taps × n_freqs × n_times`.
"""
function tfr_estimate_size(n_epochs, n_channels, n_taps, n_freqs, n_times)
    n_elements = n_epochs * n_channels * n_taps * n_freqs * n_times
    n_bytes = n_elements * sizeof(ComplexF64)  # 16 bytes per element
    return n_bytes / 2^30
end

"""
    weights_estimate_size(n_taps, n_freqs, n_times)

Estimate the size, in GB (1024^3 bytes), of a `Float64` array with
dimensions `n_taps × n_freqs × n_times`.
"""
function weights_estimate_size(n_taps, n_freqs, n_times)
    n_elements = n_taps * n_freqs * n_times
    n_bytes = n_elements * sizeof(Float64)  # 8 bytes per element
    return n_bytes / 2^30
end

"""
    Ws_estimate_size(n_taps, freqs, sfreq, n_cycles)

Estimate the total size, in GB (1024^3 bytes), of `ComplexF64` buffers holding
`n_taps` windows per frequency.

For every frequency `f` in `freqs` the window spans `n_cycles / f` time units,
i.e. `ceil(n_cycles / f * sfreq)` samples at sampling rate `sfreq`, and
`n_taps` such windows are counted.

Returns `0.0` when `freqs` is empty.
"""
function Ws_estimate_size(n_taps, freqs, sfreq, n_cycles)
    element_size = 16 # bytes for ComplexF64
    total_elements = 0
    # Iterate the frequencies themselves rather than indexing up to an
    # externally defined `n_freqs`, so the function does not depend on globals.
    for f in freqs
        t_win = n_cycles / f
        len_t = ceil(Int, t_win * sfreq)
        total_elements += n_taps * len_t
    end
    total_bytes = total_elements * element_size
    return total_bytes / (1024^3) # Convert bytes to GB
end

"""
    fft_estimate_size(a, b, c)

Estimate the size, in GB (1024^3 bytes), of a `ComplexF64` array with
dimensions `a × b × c`.
"""
function fft_estimate_size(a, b, c)
    n_elements = a * b * c
    n_bytes = n_elements * sizeof(ComplexF64)  # 16 bytes per element
    return n_bytes / 2^30
end

"""
    psd_estimate_size(n_epochs, n_channels, n_freqs, n_times)

Estimate the size, in GB (1024^3 bytes), of a `Float64` array with
dimensions `n_epochs × n_channels × n_freqs × n_times`.
"""
function psd_estimate_size(n_epochs, n_channels, n_freqs, n_times)
    n_elements = n_epochs * n_channels * n_freqs * n_times
    n_bytes = n_elements * sizeof(Float64)  # 8 bytes per element
    return n_bytes / 2^30
end

"""
    coh_estimate_size(n_epochs, n_channels, n_freqs)

Estimate the size, in GB (1024^3 bytes), of a `Float64` array with
dimensions `n_epochs × n_channels × n_channels × n_freqs`.
"""
function coh_estimate_size(n_epochs, n_channels, n_freqs)
    n_elements = n_epochs * n_channels * n_channels * n_freqs
    n_bytes = n_elements * sizeof(Float64)  # 8 bytes per element
    return n_bytes / 2^30
end

# Print the sizes (GB) of the arrays whose size does not depend on the batch size.
function _print_static_sizes(data_size, weights, Ws, fft_Ws, fft_X, coherence_mean)
    println("Data: $(data_size) GB")
    println("Weights: $weights GB")
    println("Ws: $Ws GB")
    println("FFT Ws: $fft_Ws GB")
    println("FFT X: $fft_X GB")
    println("Coherence Mean: $coherence_mean GB")
    return nothing
end

# Print the static + dynamic memory breakdown shared by the failure and summary reports.
function _print_memory_breakdown(static_sizes, static_total, dynamic_sizes, dynamic_total)
    tfr_size, psd_per_epoch, coherence, coherence_mean_small = dynamic_sizes
    println("--------------Static arrays--------------")
    _print_static_sizes(static_sizes...)
    println("Static Total: $static_total GB")
    println("-----Dynamically calculated arrays-----")
    println("TFR: $tfr_size GB")
    println("PSD: $psd_per_epoch GB")
    println("Coherence: $coherence GB")
    println("Coherence Mean (small): $coherence_mean_small GB")
    println("Dynamic Total: $dynamic_total GB")
    println("---------------------------------")
    return nothing
end

# Estimate the sizes (GB) of the arrays whose leading dimension is the batch size:
# (TFR, PSD, coherence, per-batch coherence mean).
function _dynamic_sizes(n_epochs, n_channels, n_taps, n_freqs, n_times)
    tfr_size = tfr_estimate_size(n_epochs, n_channels, n_taps, n_freqs, n_times)
    psd_per_epoch = psd_estimate_size(n_epochs, n_channels, n_freqs, n_times)
    coherence = coh_estimate_size(n_epochs, n_channels, n_freqs)
    coherence_mean_small = coh_estimate_size(n_epochs, n_channels, 1)
    return (tfr_size, psd_per_epoch, coherence, coherence_mean_small)
end

"""
    scale_dimensions(data, n_taps, freqs, sfreq, n_cycles; print=false, max_gb=0, reserve_gb=0)

Return the largest number of epochs (batch size) whose estimated working set
stays below a memory budget.

# Arguments
- `data`: array whose `size` unpacks as `(n_epochs, n_channels, n_times)`.
- `n_taps`: number of tapers.
- `freqs`: collection of frequencies; the lowest one determines the longest
  window (`n_cycles / minimum(freqs)`) used for the FFT length estimate.
- `sfreq`: sampling frequency.
- `n_cycles`: number of cycles per window.

# Keywords
- `print`: when `true`, print a breakdown of the estimated sizes.
- `max_gb`: memory budget in GB (1024^3 bytes). `0` means "total system
  memory minus `reserve_gb`".
- `reserve_gb`: GB to leave unused. If the free memory is not larger than the
  budget, a warning is logged and the budget becomes free memory minus
  `reserve_gb`.

# Returns
The batch size, an integer between 1 and the number of epochs in `data`.

# Throws
`ErrorException` (via `error`) when the budget is not positive, when `data`
has no epochs, when the batch-independent arrays alone exceed the budget or the
free memory, or when not even a single epoch fits.
"""
function scale_dimensions(data, n_taps, freqs, sfreq, n_cycles; print=false, max_gb=0, reserve_gb=0)
    system_mem = Sys.total_memory() / (1024^3)
    if max_gb == 0
        max_gb = system_mem - reserve_gb # Leave X GB for other stuff
    end
    current_mem = Sys.free_memory() / (1024^3)
    if current_mem <= max_gb
        @warn "Current free memory: $current_mem GB\nDesired max: $max_gb GB\nSystem total memory: $(system_mem) GB\nMemory available is less than desired max!\nAttempting to use available memory!"
        max_gb = current_mem - reserve_gb
    end
    if max_gb <= 0
        error("Not enough memory available!")
    end

    n_epoch_org, n_channels, n_times = size(data)
    # Without this guard, empty input would fall through to the misleading
    # "cannot compute one epoch" memory error below.
    if n_epoch_org == 0
        error("Data contains no epochs!")
    end

    # FFT length: data length plus the longest window, rounded up to a fast size.
    n_freqs = length(freqs)
    t_win = n_cycles / minimum(freqs)
    max_len = Int(ceil(t_win * sfreq))
    nfft = next_fast_len(n_times + max_len - 1)

    # Arrays that are allocated once, independent of the batch size.
    data_size = Base.summarysize(data) / (1024^3)
    weights = weights_estimate_size(n_taps, n_freqs, n_times)
    Ws = Ws_estimate_size(n_taps, freqs, sfreq, n_cycles)
    fft_Ws = fft_estimate_size(n_taps, n_freqs, nfft)
    fft_X = fft_estimate_size(n_epoch_org, n_channels, nfft)
    coherence_mean = coh_estimate_size(n_epoch_org, n_channels, 1)
    static_sizes = (data_size, weights, Ws, fft_Ws, fft_X, coherence_mean)
    static_total = data_size + weights + Ws + fft_Ws + fft_X + coherence_mean

    current_mem = Sys.free_memory() / (1024^3)

    if static_total >= max_gb || static_total >= current_mem
        println("Static calculations will exceed memory!")
        println("---------------------------------")
        println("Current free memory: $current_mem GB")
        println("---------------------------------")
        _print_static_sizes(static_sizes...)
        println("---------------------------------")
        println("Total: $static_total GB")
        println("Exceeds maximum memory limit of $max_gb GB or current free memory")
        error("Memory limit exceeded")
    end

    # Start from a single pass over all epochs. In that case no separate
    # per-batch coherence mean is counted (it is already part of the static set).
    n_epochs = n_epoch_org
    tfr_size, psd_per_epoch, coherence, _ =
        _dynamic_sizes(n_epochs, n_channels, n_taps, n_freqs, n_times)
    coherence_mean_small = 0.0
    dynamic_total = tfr_size + psd_per_epoch + coherence + coherence_mean_small

    # Shrink the batch one epoch at a time until the estimate fits the budget.
    while dynamic_total + static_total >= max_gb && n_epochs > 0
        n_epochs -= 1
        tfr_size, psd_per_epoch, coherence, coherence_mean_small =
            _dynamic_sizes(n_epochs, n_channels, n_taps, n_freqs, n_times)
        dynamic_total = tfr_size + psd_per_epoch + coherence + coherence_mean_small
    end

    if n_epochs == 0
        # Report what ONE epoch would need: the sizes left over from the loop
        # are those of an empty batch (all zero) and would hide the actual
        # requirement that exceeded the limit.
        one_epoch = _dynamic_sizes(1, n_channels, n_taps, n_freqs, n_times)
        one_epoch_total = one_epoch[1] + one_epoch[2] + one_epoch[3] + one_epoch[4]
        current_mem = Sys.free_memory() / (1024^3)
        println("Can not even compute one epoch with current memory!")
        println("---------------------------------")
        println("Current free memory: $current_mem GB")
        println("System total memory: $system_mem GB")
        _print_memory_breakdown(static_sizes, static_total, one_epoch, one_epoch_total)
        println("Total: $(one_epoch_total + static_total) GB")
        println("Exceeds maximum memory limit of $max_gb GB")
        error("Memory limit exceeded")
    end

    if print
        current_mem = Sys.free_memory() / (1024^3)
        println("Can be computed with batches of $n_epochs epochs")
        println("total batches: $(ceil(n_epoch_org/n_epochs))")
        _print_memory_breakdown(
            static_sizes,
            static_total,
            (tfr_size, psd_per_epoch, coherence, coherence_mean_small),
            dynamic_total,
        )
        println("Total: $(dynamic_total+static_total) GB")
        println("Desired max: $max_gb GB")
        println("---------------------------------")
        println("Current free memory: $current_mem GB")
        println("System total memory: $system_mem GB")
    end
    return n_epochs
end


"""
    tril_indices(n)

Return every index pair `(x, y)` with `1 <= x < y <= n`, ordered by `x` and
then by `y`. The result has `n * (n - 1) ÷ 2` entries.
"""
function tril_indices(n)::Array{Tuple{Int,Int},1}
    n_pairs = n * (n - 1) ÷ 2
    out = Vector{Tuple{Int,Int}}(undef, n_pairs)
    slot = 0
    for first in 1:n, second in (first+1):n
        slot += 1
        out[slot] = (first, second)
    end
    return out
end


"""
    next_fast_len(target::Int)::Int

Find the next fast size of input data to `fft`, for zero-padding, etc.

Returns the next composite of the prime factors 2, 3, and 5 which is
greater than or equal to `target`. (These are also known as 5-smooth
numbers, regular numbers, or Hamming numbers.)

# Arguments
- `target`: length to start searching from. Must be a positive integer;
  values `<= 6` (including non-positive ones) are returned unchanged.

# Returns
The first 5-smooth number greater than or equal to `target`.
"""
function next_fast_len(target::Int)::Int
    # Precomputed Hamming numbers (5-smooth numbers) for quick lookup
    hams = [
        8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, 50,
        54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, 135, 144,
        150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, 256, 270, 288,
        300, 320, 324, 360, 375, 384, 400, 405, 432, 450, 480, 486, 500, 512,
        540, 576, 600, 625, 640, 648, 675, 720, 729, 750, 768, 800, 810, 864,
        900, 960, 972, 1000, 1024, 1080, 1125, 1152, 1200, 1215, 1250, 1280,
        1296, 1350, 1440, 1458, 1500, 1536, 1600, 1620, 1728, 1800, 1875, 1920,
        1944, 2000, 2025, 2048, 2160, 2187, 2250, 2304, 2400, 2430, 2500, 2560,
        2592, 2700, 2880, 2916, 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600,
        3645, 3750, 3840, 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800,
        4860, 5000, 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250,
        6400, 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100,
        8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000
    ]

    # 1..6 are all 5-smooth already.
    if target <= 6
        return target
    end

    # Check if target is already a power of 2
    if (target & (target - 1)) == 0
        return target
    end

    # Quick lookup for small sizes
    if target <= hams[end]
        idx = searchsortedfirst(hams, target)
        return hams[idx]
    end

    # Number of bits needed to represent x. Computed with integer arithmetic:
    # going through `log2` converts to Float64, which rounds values such as
    # 2^55 - 1 up to 2^55 and reports one bit too many.
    bit_length(x::Int) = x <= 0 ? 0 : 8 * sizeof(x) - leading_zeros(x)

    match = typemax(Int)  # Smallest candidate found so far
    p5 = 1
    while p5 < target
        p35 = p5  # p35 runs over 5^i * 3^j
        while p35 < target
            # Smallest power of two p2 such that p2 * p35 >= target
            quotient = cld(target, p35)
            p2 = 2^bit_length(quotient - 1)
            N = p2 * p35
            if N == target
                return N
            elseif N < match
                match = N
            end
            p35 *= 3
            if p35 == target
                return p35
            end
        end
        # p35 itself (no factor of two) is a candidate once it passed target
        if p35 < match
            match = p35
        end
        p5 *= 5
        if p5 == target
            return p5
        end
    end
    if p5 < match
        match = p5
    end
    return match
end


# Driver cell: load a small slice of the input, plan the batch size from the
# available memory, build tapers, precompute FFTs and compute TFR and PSD.
# NOTE(review): both paths are absolute and machine-specific.
outputpath = "/media/dan/Data/git/network_mining/connectivity/julia_test/"
data = npzread("/media/dan/Data/git/network_mining/connectivity/julia_test/034_input.npy")
# Keep the first 2 entries of dim 1 and the first 10 of dim 2
# (unpacked below as n_epochs, n_channels, n_times).
data = data[1:2, 1:10, :]

# Analysis parameters.
sfreq = 2048
freqs = collect(14:50)
zero_mean = true  # NOTE(review): not used anywhere in the code visible here
n_freqs = length(freqs)
mt_bandwidth = 4
n_taps = floor(Int, mt_bandwidth - 1)
n_cycles = 7
n_epochs, n_channels, n_times = size(data)

# Largest number of epochs that fits the memory budget (60 GB held back).
batch_size = scale_dimensions(data, n_taps, freqs, sfreq, n_cycles, print=true, reserve_gb=60)
# `/` yields a Float64, so total_batches is a float (printed as e.g. "4.0").
total_batches = ceil(n_epochs / batch_size)
if batch_size != n_epochs
    println("Data is too big for one pass!\nData will be computed in batches of $batch_size epochs. Total batches: $(total_batches)")
end


println("Making tapers...")
# NOTE(review): compute_tapers is not defined in the code visible here;
# presumably it comes from another cell or file - confirm.
Ws, weights = compute_tapers(n_times, n_taps, freqs, mt_bandwidth, n_cycles, sfreq)
weights_squared = weights .^ 2  # NOTE(review): not used in the code visible here
# 2 divided by the sum of |weights|^2 along the first dimension.
normalization = 2 ./ sum(real(weights .* conj(weights)), dims=1);
# Drop the singleton first dimension left by the reduction above.
small_norm = dropdims(normalization; dims=1)

# NOTE(review): _get_nfft, precompute_fft_Ws and precompute_fft_X are not
# defined in the code visible here.
nfft = _get_nfft(Ws, data)

println("Precomputing FFTs of tapers and data...")
fft_Ws = precompute_fft_Ws(Ws, nfft);
fft_X = precompute_fft_X(data, nfft);
println("Done!")

# save(joinpath(outputpath, "034_pretasks.jld2"), "Ws", Ws, "weights", weights, "fft_Ws", fft_Ws, "fft_X", fft_X, "normalization", normalization)

# Length of each element of Ws.
Ws_lengths = [length(Wk) for Wk in Ws]

println("Preparing for computation...")
# All channel index pairs (x, y) with x < y.
# NOTE(review): the global name `pairs` coincides with `Base.pairs`.
pairs = tril_indices(n_channels)
n_pairs = length(pairs)


# Output buffers are allocated uninitialised (`undef`) and handed to the
# in-place (`!`) routines, which presumably fill them - compute_tfr! and
# compute_psd! are not defined in the code visible here.
tfr = Array{ComplexF64,5}(undef, batch_size, n_channels, n_taps, n_freqs, n_times);
compute_tfr!(tfr, fft_X, fft_Ws, Ws_lengths);

psd_per_epoch = Array{Float64,4}(undef, batch_size, n_channels, n_freqs, n_times);
compute_psd!(psd_per_epoch, tfr, weights, small_norm);


In [80]:

# Second cell: re-run only the memory planning with a wider frequency range
# (14:250), reusing the `data` global defined by an earlier cell.
sfreq = 2048
freqs = collect(14:250)
zero_mean = true  # NOTE(review): not used anywhere in the code visible here
n_freqs = length(freqs)
mt_bandwidth = 4
n_taps = floor(Int, mt_bandwidth - 1)
n_cycles = 7
n_epochs, n_channels, n_times = size(data)

# Largest number of epochs that fits the memory budget (60 GB held back).
batch_size = scale_dimensions(data, n_taps, freqs, sfreq, n_cycles, print=true, reserve_gb=60)
# `/` yields a Float64, so total_batches is a float (printed as e.g. "4.0").
total_batches = ceil(n_epochs / batch_size)
if batch_size != n_epochs
    println("Data is too big for one pass!\nData will be computed in batches of $batch_size epochs. Total batches: $(total_batches)")
end



Can be computed with batches of 152 epochs
total batches: 4.0
--------------Static arrays--------------
Data: 0.4484711214900017 GB
Weights: 0.00542449951171875 GB
Ws: 0.0018763840198516846 GB
FFT Ws: 0.021697998046875 GB
FFT X: 1.79388427734375 GB
Coherence Mean: 0.04248212277889252 GB
Static Total: 2.3138364031910896 GB
-----Dynamically calculated arrays-----
TFR: 159.9576416015625 GB
PSD: 26.65960693359375 GB
Coherence: 2.5253729224205017 GB
Coherence Mean (small): 0.010655581951141357 GB
Dynamic Total: 189.1532770395279 GB
---------------------------------
Total: 191.46711344271898 GB
Desired max: 191.5341453552246 GB
---------------------------------
Current free memory: 237.98957443237305 GB
System total memory: 251.5341453552246 GB
Data is too big for one pass!
Data will be computed in batches of 152 epochs. Total batches: 4.0
