# In [1]:
using LinearAlgebra, FFTW, Plots

"""
    sinusoidal(a, f, sr, t, theta=0, DC=0)

Sample a sine wave: return the vector of `t` values
`DC + a * sin(2π * f * n / sr + theta)` for `n = 0, 1, …, t-1`.

# Arguments
- `a`: amplitude.
- `f`: frequency (same time unit as `sr`).
- `sr`: sampling rate; consecutive samples are `1 / sr` apart.
- `t`: number of samples to generate.
- `theta`: phase offset in radians (default `0`).
- `DC`: constant offset added to every sample (default `0`).
"""
function sinusoidal(a, f, sr, t, theta=0, DC=0)
    dt = 1 / sr
    omega = f * 2 * π
    samples = collect(DC + (a * sin(omega * n * dt + theta)) for n in 0:(t-1))
    return samples
end
"""
    next_fast_len(target::Int) -> Int

Find the next fast size of input data to `fft`, for zero-padding, etc.

Return the smallest 5-smooth number (an integer whose only prime factors are
2, 3 and 5; also known as a regular or Hamming number) that is greater than or
equal to `target`.

# Arguments
- `target::Int`: length to start searching from. Intended to be a positive
  integer; any `target <= 6` (every integer 1–6 is already 5-smooth) is
  returned unchanged.

# Returns
- `Int`: the first 5-smooth number greater than or equal to `target`.
"""
function next_fast_len(target::Int)::Int
    # Precomputed Hamming numbers (5-smooth numbers) for quick lookup
    hams = [
        8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, 50,
        54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, 135, 144,
        150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, 256, 270, 288,
        300, 320, 324, 360, 375, 384, 400, 405, 432, 450, 480, 486, 500, 512,
        540, 576, 600, 625, 640, 648, 675, 720, 729, 750, 768, 800, 810, 864,
        900, 960, 972, 1000, 1024, 1080, 1125, 1152, 1200, 1215, 1250, 1280,
        1296, 1350, 1440, 1458, 1500, 1536, 1600, 1620, 1728, 1800, 1875, 1920,
        1944, 2000, 2025, 2048, 2160, 2187, 2250, 2304, 2400, 2430, 2500, 2560,
        2592, 2700, 2880, 2916, 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600,
        3645, 3750, 3840, 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800,
        4860, 5000, 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250,
        6400, 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100,
        8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000
    ]

    target <= 6 && return target

    # Already a power of 2 (single set bit).
    (target & (target - 1)) == 0 && return target

    # Small sizes: binary-search the table.
    if target <= hams[end]
        return hams[searchsortedfirst(hams, target)]
    end

    # Exact number of bits needed to represent x (0 for x <= 0). Integer
    # arithmetic is required: a `log2`-based version overestimates by one when
    # x = 2^k - 1 is too large to be exact in Float64 (k > 53), which makes the
    # search skip the exact 5-smooth answer (e.g. target = 3 * 2^54).
    bitlen(x::Int) = x <= 0 ? 0 : 8 * sizeof(x) - leading_zeros(x)

    best = typemax(Int)  # any candidate found will be smaller
    p5 = 1
    while p5 < target
        p35 = p5
        while p35 < target
            # Smallest power of two p2 such that p2 * p35 >= target.
            quotient = cld(target, p35)
            p2 = 2^bitlen(quotient - 1)
            candidate = p2 * p35
            candidate == target && return candidate
            best = min(best, candidate)
            p35 *= 3
            p35 == target && return p35
        end
        best = min(best, p35)
        p5 *= 5
        p5 == target && return p5
    end
    return min(best, p5)
end

"""
    _extend(M::Int, sym::Bool) -> (length, needs_trunc)

Length bookkeeping for periodic windows: when `sym` is `false` the window is
computed with one extra sample (`M + 1`) and `needs_trunc` is `true`, so the
caller drops the last sample afterwards. When `sym` is `true`, return
`(M, false)` unchanged.
"""
_extend(M::Int, sym::Bool)::Tuple{Int, Bool} = sym ? (M, false) : (M + 1, true)


"""
    _fftautocorr(x::AbstractMatrix{<:Float64}) -> Matrix{Float64}

Row-wise autocorrelation of `x` for lags `0:N-1` (with `N = size(x, 2)`),
computed via FFT. Each row is zero-padded to `next_fast_len(2N - 1)` samples,
which is at least `2N - 1`, so the circular correlation equals the linear one
for the returned lags.

Tested vs python:
    isapprox(x_fft, py_x_fft, atol=1e-12) == true
    isapprox(py_cxy, cxy, atol=1e-12) == true
"""
function _fftautocorr(x::AbstractMatrix{<:Float64})::Array{Float64, 2}
    nrows, ncols = size(x)
    nfft = next_fast_len(2 * ncols - 1)

    buf = zeros(Float64, nrows, nfft)
    buf[:, 1:ncols] .= x

    # Transform along rows (dimension 2); |X|^2 is the power spectrum.
    spectrum = plan_rfft(buf, 2) * buf
    power = spectrum .* conj.(spectrum)

    return irfft(power, nfft, 2)[:, 1:ncols]
end


"""
    py_dpss(M, NW; Kmax=nothing, sym=true, norm=nothing, return_ratios=false)

Compute the Discrete Prolate Spheroidal Sequences (DPSS). This is a port of
SciPy's `scipy.signal.windows.dpss` and keeps its argument conventions.

# Arguments
- `M::Int`: window length.
- `NW::Float64`: standardized half bandwidth corresponding to
  `2*NW = BW/f0 = BW*M*dt`, where `dt` is taken as 1. Must satisfy
  `0 < NW < M/2`.

# Keywords
- `Kmax::Union{Int,Nothing}=nothing`: number of DPSS windows to return (orders
  `0` through `Kmax-1`), with `0 < Kmax <= M`. If `nothing`, return only a
  single window of shape `(M,)` instead of an array of windows of shape
  `(Kmax, M)`.
- `sym::Bool=true`: when `true`, generate a symmetric window, for use in filter
  design. When `false`, generate a periodic window, for use in spectral
  analysis.
- `norm::Union{Int,String,Nothing}=nothing`: `2` keeps the eigenvector (l2)
  normalization. `"approximate"` or `"subsample"` normalize the windows by
  their maximum absolute value. For an even (possibly extended) length they
  also apply a correction factor: `M^2/(M^2+NW)` for `"approximate"`, or one
  derived from an FFT-based subsample shift for `"subsample"`. If `nothing`,
  `"approximate"` is used when `Kmax === nothing` and `2` otherwise.
- `return_ratios::Bool=false`: if `true`, also return the concentration ratios
  in addition to the windows.

# Returns
- `windows`: a `Kmax × M` `AbstractMatrix` (possibly a lazy `Adjoint`), or a
  length-`M` vector when `Kmax` is `nothing`.
- `ratios` (only when `return_ratios`): the concentration ratios, as a vector,
  or a scalar when `Kmax` is `nothing`.

Throws an `ErrorException` (via `error`) for an unknown `norm`, for `Kmax`
outside `1:M`, or for `NW` outside `(0, M/2)`.
"""
function py_dpss(M::Int, NW::Float64; Kmax::Union{Int,Nothing}=nothing, sym::Bool=true, norm::Union{Int,String,Nothing}=nothing, return_ratios::Bool=false)
    # Default normalization: "approximate" for a single window, l2 otherwise.
    if norm === nothing
        norm = Kmax === nothing ? "approximate" : 2
    end
    known_norms = [2, "approximate", "subsample"]
    if !(norm in known_norms)
        error("norm must be one of $known_norms, got $norm")
    end
    # `Kmax === nothing` is solved as a 1-window problem and flattened at the end.
    singleton = Kmax === nothing
    if singleton
        Kmax = 1
    end
    if !(0 < Kmax <= M)
        error("Kmax must be greater than 0 and less than or equal to M")
    end
    if NW >= M / 2.0
        error("NW must be less than M/2.")
    end
    if NW <= 0
        error("NW must be positive")
    end
    # Periodic windows are computed one sample longer and truncated at the end.
    M, needs_trunc = _extend(M, sym)
    W = NW / M
    nidx = collect(0:M-1)
    # Symmetric tridiagonal matrix whose top eigenvectors are the DPSS windows
    # (same construction as SciPy).
    d = ((M - 1 .- 2 .* nidx) ./ 2.0) .^ 2 .* cos.(2pi * W)
    e = nidx[2:end] .* (M .- nidx[2:end]) ./ 2.0
    T = Tridiagonal(e, d, e)
    eigs = eigen(T)
    # Eigenvalues of a symmetric matrix come back in ascending order, so the
    # last Kmax columns belong to the Kmax largest eigenvalues.
    indices = M-Kmax+1:M
    # NOTE(review): rows are also reversed here (SciPy does not do this);
    # presumably harmless because the sign conventions are re-imposed below —
    # confirm if bitwise agreement with SciPy matters.
    eigvecs = eigs.vectors[end:-1:1, indices]
    # Reverse the columns so row 1 is the window with the largest eigenvalue.
    windows = eigvecs[:, end:-1:1]'
    # Correct the sign conventions: even-order windows (rows 1, 3, …) must
    # have a non-negative sum.
    fix_even = sum(windows[1:2:end, :], dims=2) .< 0
    for (i, f) in enumerate(fix_even)
        if f[1]
            windows[2i-1, :] *= -1
        end
    end
    # Odd-order windows (rows 2, 4, …): the first sample whose square exceeds
    # `thresh` must be positive.
    thresh = max(1e-7, 1.0 / M)
    for (i, w) in enumerate(eachrow(windows[2:2:end, :]))
        idx = findfirst(x -> x^2 > thresh, w)
        if idx !== nothing && w[idx] < 0
            windows[2i, :] *= -1
        end
    end

    if return_ratios
        # Concentration ratios from the window autocorrelation dotted with the
        # (normalized) sinc kernel of the band [-W, W].
        dpss_rxx = _fftautocorr(windows)
        r = 4 * W * sinc.(2 * W .* nidx)
        r[1] = 2 * W
        ratios = dpss_rxx * r
        if singleton
            ratios = ratios[1]
        end
    end

    if norm != 2
        # checked vs python and it is the same
        # NOTE(review): SciPy appears to divide by `windows.max()` rather than
        # the maximum absolute value; they agree whenever the largest-magnitude
        # sample is positive — confirm for multi-window "approximate" use.
        windows .= windows ./ maximum(abs.(windows))
        if iseven(M)
            if norm == "approximate"
                correction = M^2 / (M^2 + NW)
            else
                s = rfft(windows[1, :])
                shift = -(1 - 1.0 / M) .* collect(1:Int(M / 2))
                s[2:end] .= s[2:end] .* (2 .* exp.(-im * π .* shift))
                correction = M / sum(real.(s))
            end
            windows .= windows .* correction
        end
    end
    if needs_trunc
        windows = windows[:, 1:end-1]
    end
    if singleton
        windows = windows[1, :]
    end
    if return_ratios
        return windows, ratios
    else
        return windows
    end
end


"""
    new_py_dpss(M, NW, normalization_type, Kmax; sym=true)

Compute `Kmax` Discrete Prolate Spheroidal Sequences (DPSS) of length `M`
together with their concentration ratios. This is a positional-argument variant
of `py_dpss` that always returns the ratios.

# Arguments
- `M::Int`: window length.
- `NW::Float64`: standardized half bandwidth; must satisfy `0 < NW < M/2`.
- `normalization_type::Int`:
  - `1`: no rescaling (corresponds to `py_dpss(...; norm=2)`).
  - `2`: normalize by the maximum absolute value and, for an even (possibly
    extended) length, apply the correction `M^2/(M^2+NW)` ("approximate").
  - `3`: as for `2`, but with a correction derived from an FFT-based subsample
    shift ("subsample").
- `Kmax::Int`: number of windows, `0 < Kmax <= M`.

# Keywords
- `sym::Bool=true`: `true` for a symmetric window; `false` for a periodic
  window (computed one sample longer, then truncated).

# Returns
`(windows, ratios)`:
- `Kmax > 1`: `windows` is a `Kmax × M` `Matrix{Float64}` and `ratios` is a
  `Vector{Float64}`.
- `Kmax == 1`: `windows` is a length-`M` `Vector{Float64}` and `ratios` is a
  `Float64`.

Throws an `ErrorException` (via `error`) on invalid arguments.
"""
function new_py_dpss(M::Int, NW::Float64, normalization_type::Int, Kmax::Int; sym::Bool=true)::Tuple{Union{Matrix{Float64}, Vector{Float64}}, Union{Vector{Float64}, Float64}}
    # Validate inputs
    known_norms = (1, 2, 3)
    if normalization_type ∉ known_norms
        error("normalization_type must be one of $known_norms, got $normalization_type")
    end
    if !(0 < Kmax <= M)
        error("Kmax must be greater than 0 and less than or equal to M")
    end
    if NW >= M / 2.0
        error("NW must be less than M/2.")
    end
    if NW <= 0
        error("NW must be positive")
    end
    # A single window is returned flattened (vector window, scalar ratio).
    singleton = Kmax == 1

    M, needs_trunc = _extend(M, sym)
    W = NW / M
    nidx = collect(0:M-1)
    d = ((M - 1 .- 2 .* nidx) ./ 2.0) .^ 2 .* cos.(2pi * W)
    e = nidx[2:end] .* (M .- nidx[2:end]) ./ 2.0
    T = SymTridiagonal(d, e)
    eigs = eigen(T)
    # Eigenvalues come back ascending; keep the eigenvectors of the Kmax
    # largest and put the most concentrated window in row 1.
    indices = M-Kmax+1:M
    windows = eigs.vectors[end:-1:1, indices]
    windows = windows[:, end:-1:1]'

    # Even-order windows (rows 1, 3, …) must have a non-negative sum. Index
    # `windows` directly: chaining `windows[rows, :][mask, :] .*= -1` would
    # only modify a temporary copy and silently skip the correction.
    even_rows = 1:2:size(windows, 1)
    flip = even_rows[vec(sum(windows[even_rows, :], dims=2) .< 0)]
    windows[flip, :] .*= -1

    # Odd-order windows (rows 2, 4, …): the first sample whose square exceeds
    # `thresh` must be positive.
    thresh = max(1e-7, 1.0 / M)
    for (i, w) in enumerate(eachrow(windows[2:2:end, :]))
        idx = findfirst(x -> x^2 > thresh, w)
        if idx !== nothing && w[idx] < 0
            windows[2i, :] *= -1
        end
    end

    # Concentration ratios: autocorrelation of each window dotted with the
    # (normalized) sinc kernel of the band [-W, W].
    dpss_rxx = _fftautocorr(windows)
    r = 4 * W * sinc.(2 * W .* nidx)
    r[1] = 2 * W
    ratios = dpss_rxx * r
    if singleton
        ratios = ratios[1]
    end

    # Apply normalization if needed
    if normalization_type != 1
        windows ./= maximum(abs, windows)
        if iseven(M)
            correction = if normalization_type == 2
                M^2 / (M^2 + NW)
            else  # normalization_type == 3
                s = rfft(windows[1, :])
                shift = -(1 - 1.0 / M) .* (1:Int(M / 2))
                s[2:end] .*= 2 .* exp.(-im * π .* shift)
                M / sum(real(s))
            end
            windows .*= correction
        end
    end

    if needs_trunc
        windows = windows[:, 1:end-1]
    end
    # Materialize explicitly: `windows` may still be a lazy `Adjoint`.
    singleton && return windows[1, :], ratios
    return Matrix(windows), ratios
end

# Cross-check the two DPSS implementations on a periodic (sym=false)
# three-window configuration. They build the eigenproblem with different matrix
# wrappers (`Tridiagonal` vs `SymTridiagonal`), which may dispatch to different
# eigen routines, so compare with `isapprox` rather than exact `==`.
M = 112
NW = 2.0
Kmax = 3
sym = false

dpss_out, eigenvals = py_dpss(M, NW, Kmax=Kmax, sym=sym, norm=2, return_ratios=true)

dpss_out_new, eigenvals_new = new_py_dpss(M, NW, 1, Kmax, sym=sym)
isapprox(dpss_out, dpss_out_new), isapprox(eigenvals, eigenvals_new)

# Out[1]: (true, true)