# EMG Decomposition Tutorial in Julia - ISB 2025

**Author:** Simon Avrillon - Nantes Université

This tutorial demonstrates how to perform EMG signal decomposition using Julia, including preprocessing, decomposition, and visualisation steps.

## Installing Julia

1. **Download Julia:**  
    Visit the [official Julia website](https://julialang.org/downloads/) and download the installer for your operating system.

2. **Install Julia:**  
    Run the downloaded installer and follow the on-screen instructions.

3. **Add Julia to PATH (optional):**  
    For easier access from the terminal, ensure Julia is added to your system's PATH variable.

4. **Verify Installation:**  
    Open a terminal or command prompt and type:
    ```
    julia --version
    ```
    You should see the installed Julia version printed.

5. **Install Required Packages:**  
    Run the first code cell of this notebook to install all the required Julia packages.

## Installing Python and the `mne` Package

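Some parts of this tutorial require Python and the `mne` package (for reading EDF files). Julia calls Python through `PyCall`, which must use the same Python installation into which you install `mne`. On macOS and Windows, PyCall uses its own Conda-based Python by default. Either point it to your Python (set `ENV["PYTHON"]` to the path of your Python executable, run `Pkg.build("PyCall")` and restart Julia) or install `mne` into PyCall's Conda environment.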

1. **Install Python:**  
    Download Python from the [official website](https://www.python.org/downloads/) and follow the installation instructions for your OS.

2. **Verify Python Installation:**  
    Open a terminal or command prompt and type:
    ```
    python --version
    ```
    or
    ```
    python3 --version
    ```
    You should see the installed Python version printed.

3. **Install the `mne` Package:**  
    In your terminal or command prompt, run:
    ```
    pip install mne
    ```
    or
    ```
    pip3 install mne
    ```

4. **Check Installation:**  
    In Python, try importing `mne`:
    ```python
    import mne
    print(mne.__version__)
    ```

You're now ready to run Julia code and follow along with this EMG decomposition tutorial!

## 1. Installing Julia packages

In [None]:
using Pkg

# List of required packages. `Statistics` and `LinearAlgebra` are Julia standard
# libraries; they are listed so that they are recorded in the active project too.
required_pkgs = [
    "JLD2", "PyCall", "CSV", "DataFrames", "DSP", "FFTW", "Statistics",
    "LinearAlgebra", "Clustering", "SignalAnalysis", "FindPeaks1D", "PlotlyJS",
    "StatsBase"
]

# `Pkg.installed()` is deprecated; `Pkg.project().dependencies` maps the names of the
# active project's direct dependencies to their UUIDs. Query it once, not per package.
installed_pkgs = keys(Pkg.project().dependencies)
missing_pkgs = filter(pkg -> !(pkg in installed_pkgs), required_pkgs)

# Install all missing packages with a single resolver call
isempty(missing_pkgs) || Pkg.add(missing_pkgs)

## 2. Load the data
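Change the path below so that it points to your EDF file containing the EMG data. The path to the matching channel-description TSV file is set two cells further down.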


In [None]:
using PyCall

# Import the mne module (it must be installed in the Python environment PyCall uses)
mne = pyimport("mne")

# Define your file path (use raw string to avoid escaping issues).
# Edit this path to point to your own copy of the recording.
edf_path = raw"/Users/savrillo/Dev/MUedit-julia/data/sub-01_task-isometric30percentmvc_run-01_emg.edf"

# Load the EDF file; `preload=true` asks mne to read the samples into memory now
raw = mne.io.read_raw_edf(edf_path, preload=true)

data = raw.get_data() # Signal data of all channels (channels × samples). NOTE(review): mne returns data scaled to SI units (e.g. volts) — confirm against the channel metadata

In [None]:
using CSV
using DataFrames

"""
    read_emg_tsv(path, data_matrix)

Read EMG and auxiliary data from a TSV channel-description file and a data matrix.

Row `k` of the TSV is assumed to describe row `k` of `data_matrix` (same channel
order). Channels whose `type` is `"EMG"` are grouped by their `group` column (one
entry per grid). Channels whose `type` is `"MISC"` are returned as auxiliary data.

# Arguments
- `path::AbstractString`: Path to the TSV metadata file (columns used: `type`,
  `group`, `status`, `description`, `sampling_frequency`).
- `data_matrix::AbstractMatrix{<:Real}`: Matrix of size (channels × samples) with the raw signal data.

# Returns
A `Dict` with:
- `"n_grids"`: Number of unique EMG grids
- `"emg_per_grid"`: Dict of grid name → EMG data matrix (channels × samples)
- `"aux_data"`: Auxiliary data matrix (channels × samples)
- `"aux_names"`: Names/descriptions of auxiliary channels
- `"sampling_frequency"`: Sampling rate in Hz
- `"channel_mask"`: Dict of grid name → BitVector indicating "good" channels

# Throws
- `ErrorException` if the TSV lists more than one sampling frequency.
"""
function read_emg_tsv(path::AbstractString, data_matrix::AbstractMatrix{<:Real})
    df = CSV.read(path, DataFrame, delim='\t')

    # Ensure single sampling frequency
    sampling_freqs = unique(df.sampling_frequency)
    if length(sampling_freqs) != 1
        error("Multiple sampling frequencies found: $sampling_freqs")
    end
    sampling_frequency = only(sampling_freqs)

    # Index channels by type. `isequal` treats `missing` cells as non-matching instead
    # of propagating `missing` (which would make `findall` throw).
    emg_idx = findall(isequal("EMG"), df.type)
    aux_idx = findall(isequal("MISC"), df.type)

    emg_data_matrix = data_matrix[emg_idx, :]
    aux_data_matrix = data_matrix[aux_idx, :]
    aux_names = df.description[aux_idx]

    # Prepare group and status arrays for EMG channels
    emg_groups = df.group[emg_idx]
    emg_status = df.status[emg_idx]
    grids = unique(emg_groups)

    # Initialize output structures (keys keep whatever string type CSV produced)
    emg_per_grid = Dict{Any, Matrix{eltype(data_matrix)}}()
    channel_mask = Dict{Any, BitVector}()

    # Group EMG data by grid
    for grid in grids
        grid_idx = findall(isequal(grid), emg_groups)
        emg_per_grid[grid] = emg_data_matrix[grid_idx, :]
        # `true` only for channels explicitly marked "good" (missing status → false)
        channel_mask[grid] = BitVector(isequal.(emg_status[grid_idx], "good"))
    end

    return Dict(
        "n_grids" => length(grids),
        "emg_per_grid" => emg_per_grid,
        "aux_data" => aux_data_matrix,
        "aux_names" => aux_names,
        "sampling_frequency" => sampling_frequency,
        "channel_mask" => channel_mask
    )
end

In [None]:
# Channel-description TSV of the same recording as the EDF file (edit to match your data).
# Each TSV row is assumed to describe the corresponding row of `data`.
path = "/Users/savrillo/Dev/MUedit-julia/data/sub-01_task-isometric30percentmvc_run-01_channels.tsv"
signal = read_emg_tsv(path, data)

## 3. Plot raw data

In [None]:
using PlotlyJS
# Plot the force data with the target displayed to the participant. The task was an isometric contraction with a trapezoidal force profile.
# Time axis in seconds: sample k (1-based) is acquired at (k - 1) / fs, so the step is
# exactly 1/fs. (`range(0, N/fs; length=N)` would space the points by N/(fs*(N-1)).)
timeforce = (0:(size(signal["aux_data"], 2) - 1)) ./ signal["sampling_frequency"]

# Auxiliary row 1 is plotted as the force path, row 2 as the displayed target
plot_data = [
    PlotlyJS.scatter(x=timeforce, y=signal["aux_data"][1,:], mode="lines", name="path"),
    PlotlyJS.scatter(x=timeforce, y=signal["aux_data"][2,:], mode="lines", name="target")
]

layout = PlotlyJS.Layout(title="Force data", xaxis_title="Time (s)", yaxis_title="Relative force (%)")

PlotlyJS.plot(plot_data, layout)

In [None]:
# Plot the first five channels of grid 3, each shifted vertically by `offset`.
# NOTE(review): the offset is in data units; if the data are in volts, 2 is much larger
# than typical EMG amplitudes — confirm the units.
offset = 2  
nchannels = 5
plot_data = map(1:nchannels) do ch
    shifted = signal["emg_per_grid"]["Grid3"][ch, :] .+ (ch - 1) * offset
    PlotlyJS.scatter(x=timeforce, y=shifted, mode="lines", name="channel $ch")
end

layout = PlotlyJS.Layout(
    title="EMG data (5 successive channels with offset)",
    xaxis_title="Time (s)",
    yaxis_title="EMG amplitude + offset"
)

PlotlyJS.plot(plot_data, layout)

## 4. Load the functions for EMG preprocessing

In [None]:
using DSP
using FFTW
using Statistics
using LinearAlgebra
using Clustering
using SignalAnalysis
using FindPeaks1D

"""
To remove narrow-band interference (spectral lines) from every row of the signal.

The spectrum of each channel is scanned in consecutive windows of `fsamp` frequency
bins. In each window, bins whose magnitude exceeds the window median + 5 standard
deviations are flagged. A band of about 4 Hz centred on each flagged bin is then
subtracted in the frequency domain. The removed spectrum is kept Hermitian-symmetric,
so the correction is real-valued.

# Arguments
    signal: row-wise signal (channels × samples)
    fsamp: sampling frequency in Hz (integer: it is also used as a window length in bins)

# Returns
    filtered_signal: row-wise filtered signal
"""
function notchemg(signal, fsamp)
    nsamples = size(signal, 2)
    # Number of FFT bins spanning 4 Hz (bin spacing is fsamp / nsamples Hz)
    bandwidth_as_index = round(Int, 4 * (nsamples / fsamp))
    halfband = div(bandwidth_as_index, 2)
    # Bins 2:npos are the positive frequencies (bin 1 is DC). The remaining bins hold
    # their complex conjugates, plus the Nyquist bin when nsamples is even.
    npos = cld(nsamples, 2)
    filtered_signal = zeros(size(signal))

    for ch in axes(signal, 1)
        # Perform FFT
        fourier_signal = fft(signal[ch, :])
        fourier_interf = zeros(Complex{Float64}, nsamples)
        interf2remove = Int[]

        # Iterate over windows of `fsamp` frequency bins
        for interval in 1:fsamp:(nsamples - fsamp)
            segment = abs.(@view fourier_signal[interval+1:interval+fsamp])
            threshold = median(segment) + 5 * std(segment)

            # Absolute indices of the bins flagged as interference in this window
            label_interf = findall(>(threshold), segment) .+ interval

            # Add bandwidth around interference
            for shift in -halfband:halfband
                append!(interf2remove, label_interf .+ shift)
            end
        end

        # Keep only positive-frequency bins; DC is not a spectral line to notch out
        indexf2remove = unique(filter(k -> 2 <= k <= npos, interf2remove))
        fourier_interf[indexf2remove] .= fourier_signal[indexf2remove]

        # Mirror frequencies for IFFT (conjugate symmetry of a real signal's spectrum)
        corrector = mod(nsamples, 2)
        fourier_interf[npos+1:end] .= conj.(reverse(fourier_interf[2:npos+1-corrector]))

        # Subtract interference from the original signal
        filtered_signal[ch, :] .= signal[ch, :] .- real(ifft(fourier_interf))
    end

    return filtered_signal
end

"""
Zero-phase Butterworth band-pass filtering of every row of the signal.

- `emgtype == 1` (surface): 20–500 Hz, `Butterworth(2)`
- any other value (intramuscular): 100–4400 Hz, `Butterworth(4)`

Cut-offs are normalised by the Nyquist frequency (fsamp / 2), so they must stay
below 1. For example, fsamp must exceed 1000 Hz for surface EMG.

# Arguments
    signal: row-wise signal (channels × samples)
    fsamp: sampling frequency
    emgtype: 1 = surface, 2 = intra

# Returns
    row-wise filtered signal
"""
function bandpassemg(signal, fsamp, emgtype)
    nyquist = fsamp / 2
    if emgtype == 1
        # Surface EMG
        band = Bandpass(20 / nyquist, 500 / nyquist)
        prototype = Butterworth(2)
    else
        band = Bandpass(100 / nyquist, 4400 / nyquist)
        prototype = Butterworth(4)
    end

    # filtfilt filters along columns: work on the samples × channels layout, then
    # transpose back to the row-wise layout.
    samples_by_channel = filtfilt(digitalfilter(band, prototype), transpose(signal))
    return permutedims(samples_by_channel)
end

"""
Centre every channel: subtract from each row of `signal` the mean of that row.

# Arguments
    signal: row-wise emg signal (channels × samples)

# Returns
    demsignal: row-wise emg signal with zero-mean rows
"""
function demeanemg(signal)
    rowmeans = mean(signal; dims = 2)   # one mean per channel (column vector)
    demsignal = signal .- rowmeans
    return demsignal
end

"""
Extend the signal with delayed copies to reach the number of extended channels
(1000 in Negro 2016, can be higher to improve the decomposition).

Row block `lag` (0-based, `nchan` rows each) holds the signal delayed by `lag`
samples. The output therefore has `nchan * exfactor` rows and `exfactor - 1` more
columns than the input, zero-padded where no sample exists.

# Arguments
    signal: row-wise emg signal (channels × samples)
    exfactor: extension factor (number of delayed copies, including lag 0)

# Returns
    esignal: extended signal
"""
function extendemg(signal, exfactor)
    nchan, nsamp = size(signal)
    esignal = zeros(nchan * exfactor, nsamp + exfactor - 1)
    for lag in 0:(exfactor - 1)
        rowrange = (lag * nchan + 1):((lag + 1) * nchan)
        colrange = (lag + 1):(lag + nsamp)
        esignal[rowrange, colrange] = signal
    end
    return esignal
end

"""
Whitening the EMG signal.

The covariance across channels is eigendecomposed. Eigenvalues that do not exceed
the mean of the lower half of the spectrum are discarded (regularisation of Negro
et al., 2016, after FastICA's `pcamat`). The signal is then ZCA-whitened in the
retained subspace.

# Arguments
    signal: row-wise signal (channels × samples)

# Returns
    whiteningMatrix = whitening Matrix (channels × channels)
    dewhiteningMatrix = dewhitening Matrix (channels × channels)
    whitenedsignal = whitened EMG signal, `whiteningMatrix * signal`

# Throws
    ArgumentError if the covariance matrix has no positive eigenvalue
"""
function whitenemg(signal)
    # `Symmetric` guarantees a real eigendecomposition with ascending eigenvalues
    covariance_matrix = cov(signal, dims=2)
    eigenDecomp = eigen(Symmetric(covariance_matrix))
    lambdas = eigenDecomp.values
    n = length(lambdas)

    # Estimate the regularization factor as in Negro et al., 2016
    rankTolerance = max(0, mean(lambdas[1:max(1, n ÷ 2)]))
    maxLastEig = count(>(rankTolerance), lambdas)

    if 0 < maxLastEig < n
        # Eigenvalues are ascending, so those above the tolerance are the LAST
        # `maxLastEig` ones. The reference implementation indexes a descending-sorted
        # spectrum; here the cut lies between entries n - maxLastEig and n - maxLastEig + 1.
        lowerLimitValue = (lambdas[n - maxLastEig] + lambdas[n - maxLastEig + 1]) / 2
    else
        # Nothing, or everything, lies strictly above the tolerance (e.g. equal
        # eigenvalues of an already white input): keep every positive eigenvalue.
        lowerLimitValue = zero(eltype(lambdas))
    end

    # Select the columns corresponding to the desired range of eigenvalues
    keep = lambdas .> lowerLimitValue
    any(keep) || throw(ArgumentError("whitenemg: covariance matrix has no positive eigenvalue"))
    evectors = eigenDecomp.vectors[:, keep]
    evalues = lambdas[keep]

    # ZCA whitening / dewhitening restricted to the retained subspace
    whiteningMatrix = evectors * Diagonal(1 ./ sqrt.(evalues)) * evectors'
    dewhiteningMatrix = evectors * Diagonal(sqrt.(evalues)) * evectors'
    whitenedsignal = whiteningMatrix * signal

    return whiteningMatrix, dewhiteningMatrix, whitenedsignal
end

"""
Main function to preprocess EMG signal before the decomposition.

The pipeline runs notch filtering, band-pass filtering and mean removal. It then
extends the signal with delayed copies and whitens it.

# Arguments
    signal: row-wise signal (channels × samples)
    fsamp: sampling frequency
    emgtype: 1 = surface, 2 = intra

# Keywords
    nextended: target number of extended channels (default 1000, as in Negro 2016).
               The extension factor is round(nextended / nchannels), at least 1.

# Returns
    whiteningMatrix = whitening Matrix
    dewhiteningMatrix = dewhitening Matrix
    whitenedsignal = whitened EMG signal
"""
function preprocessEMG(signal, fsamp, emgtype; nextended=1000)
    signal = notchemg(signal, fsamp)
    signal = bandpassemg(signal, fsamp, emgtype)
    signal = demeanemg(signal)
    # With more than 2 * nextended channels the rounded factor would be 0 (empty signal)
    exfactor = max(1, round(Int, nextended / size(signal, 1)))
    esignal = extendemg(signal, exfactor)
    return whitenemg(esignal)
end

## 5. EMG preprocessing

In [None]:
emgtype = 1  # EMG type for the band-pass filter: 1 = surface, 2 = intra
whiteningMatrix, dewhiteningMatrix, whitenedsignal = preprocessEMG(
    signal["emg_per_grid"]["Grid3"],    # raw EMG of grid 3 (channels × samples)
    Int(signal["sampling_frequency"]),  # integer rate: the notch filter uses it as a window length
    emgtype,
)

## 6. Load the functions for EMG decomposition

In [None]:
"""
Fixed point algorithm to iteratively optimize a set of weights (MU
filter) to maximize the sparseness of the source (MU pulse train).

Each step uses the skew contrast g(x) = x^2 (g'(x) = 2x):
    w ← mean(z * g(wᵀz)) − mean(g'(wᵀz)) * w
followed by deflation against the columns of B and normalisation.

# Arguments
    w = separation vector (initial guess)
    Z = whitened signal (residual), channels × samples
    B = separation matrix; its column space is projected out of w at every step
    maxiter = maximal number of iteration before convergence

# Keywords
    tol = tolerance between two iterations on |wᵀw_last − 1| (default 0.0001)

# Returns
   w = weights (MU filter), unit norm
"""
function fixedpointalg(w, Z, B, maxiter; tol=0.0001)
    nsamp = size(Z, 2)
    k = 1
    delta = Inf

    while delta > tol && k < maxiter
        # Save last weights
        wlast = w
        # Contrast function
        wTZ = wlast' * Z
        A = 2 * mean(wTZ)
        w = Z * transpose(wTZ .^ 2) / nsamp .- A * wlast

        # Orthogonalisation: B * (B' * w) avoids building the channels×channels B*B'
        w .-= B * (B' * w)
        # Normalisation
        w /= norm(w)

        # Update convergence criteria
        k += 1
        delta = abs(dot(w, wlast) - 1)
    end

    return w
end

"""
To identify the discharge times of the motor unit.

The pulse train is the source `wᵀZ` squared with its sign kept. Peaks at least 5 ms
apart are split into spike and noise classes by 2-class k-means. Spikes above mean
+ 3 SD of the spike class are then removed as outliers.

# Arguments
    w = separation vector
    Z = whitened signal (residual)
    fsamp = sampling frequency

# Returns
    PTs = Motor unit pulse train
    spikes = discharge indexes of the motor unit
"""
function getspikes(w, Z, fsamp)
    # Generate the motor unit pulse train (computed once)
    source = vec(w' * Z)
    PTs = source .* abs.(source)
    peaks, _ = findpeaks1d(PTs; distance=round(Int, fsamp * 0.005))     # Get the peaks
    spikes = peaks[:]
    if length(spikes) > 1
        kmeans_result = kmeans(PTs[spikes]', 2) # Separation of spikes and noise classes
        # Spikes should be in the class with the highest centroid
        idx = argmax(vec(kmeans_result.centers))
        spikes = spikes[kmeans_result.assignments .== idx]
        # Remove outliers. This needs at least two spikes: the std of a single value is
        # NaN, and the comparison would then discard the only spike.
        if length(spikes) > 1
            amplitudes = PTs[spikes]
            spikes = spikes[amplitudes .<= mean(amplitudes) + 3 * std(amplitudes)]
        end
    end

    return PTs, spikes
end

"""
Optimization loop of the MU filter to minimize the coefficient of
variation of the inter spike intervals.

Each pass detects the discharges of the current filter. It then replaces the filter
by the normalised spike-triggered average of `Z` at those discharges. The loop stops
as soon as the CoV no longer decreases.

# Arguments
    w = separation vector
    Z = whitened signal (residual)
    CoV = coefficient of variation of the inter spike intervals of the discharges
          that produced `w`
    fsamp = sampling frequency

# Returns
    wlast = new separation vector (spike-triggered average of `spikeslast`)
    spikeslast = discharge indexes of the motor unit with the lowest CoV found
    CoVlast = coefficient of variation of the inter spike intervals of `spikeslast`
"""
function minimizeCOVISI(w, Z, CoV, fsamp)
    CoVlast = CoV + 0.1    # forces at least one pass for a finite CoV
    wlast = w
    spikes = Int[]         # discharges of the filter tested in the current pass
    spikeslast = Int[]     # discharges whose CoV is CoVlast
    while CoV < CoVlast
        CoVlast = CoV # save the last CoV
        spikeslast = spikes # save the last discharge times
        wlast = w # save the last MU filter
        _, spikes = getspikes(w, Z, fsamp)
        ISI = diff(spikes / fsamp) # calculate the interspike interval
        CoV = std(ISI) / mean(ISI) # Update the CoV of the ISI
        # Next filter: spike-triggered average of the whitened signal
        w = sum(Z[:, spikes], dims=2)
        w = w / norm(w)
    end

    # Stopped on the first pass (no earlier spike train is available): use the
    # discharges of the filter that is returned.
    if length(spikeslast) < 2
        _, spikeslast = getspikes(wlast, Z, fsamp)
    end

    return wlast, spikeslast, CoVlast
end

"""
To calculate the silhouette value to estimate the quality of the Motor
Unit (the distance between the peaks and the noise).

Peaks of the pulse train (at least 5 ms apart) are clustered into two classes by
k-means. With `within` the sum of squared distances of the spike-class peaks to
their own centroid, and `between` their sum of squared distances to the noise
centroid, `sil = (between - within) / max(within, between)`.

# Arguments
    w = separation vector
    Z = whitened signal (residual)
    fsamp = sampling frequency

# Returns
    PTs = Motor unit pulse train
    sil = silhouette value (0.0 when fewer than two peaks are found)
"""
function calcsil(w, Z, fsamp)
    # Generate the motor unit pulse train (computed once)
    source = vec(w' * Z)
    PTs = source .* abs.(source)
    peaks, _ = findpeaks1d(PTs; distance=round(Int, fsamp * 0.005))     # Get the peaks
    spikes = peaks[:]
    sil = 0.0   # Float64 on every path (type-stable)
    if length(spikes) > 1
        amplitudes = PTs[spikes]
        kmeans_result = kmeans(amplitudes', 2) # Separation of spikes and noise classes
        centers = vec(kmeans_result.centers)
        idx = argmax(centers)   # spike class
        idx1 = argmin(centers)  # noise class
        spikeamps = amplitudes[kmeans_result.assignments .== idx]
        within = sum(abs2, spikeamps .- centers[idx])
        between = sum(abs2, spikeamps .- centers[idx1])
        sil = (between - within) / max(within, between)
    end
    return PTs, sil
end

"""
Cut one window of `Z` around every discharge and stack them row-wise.

Each window covers samples `spike - r : spike + r` (length 2r + 1). Discharges
closer than 2r samples to the start, or 2r + 1 samples to the end, of `Z` are
skipped.

# Arguments
    spikes = discharge indexes of the motor unit
    r = radius of the rectangular window, in samples
    Z = EMG signal from a single channel

# Returns
    MUAPs = matrix with one row per kept discharge and 2r + 1 columns
"""
function cutMUAP(spikes, r, Z)
    first_ok = 1 + 2 * r
    last_ok = length(Z) - 1 - 2 * r
    kept = filter(s -> first_ok <= s <= last_ok, spikes)

    MUAPs = zeros(length(kept), 1 + 2 * r)
    for (row, s) in pairs(kept)
        MUAPs[row, :] = Z[(s - r):(s + r)]
    end
    return MUAPs
end

"""
Central part of the full convolution of `A` and `B`, with the same length as `A`.
This matches MATLAB's `conv(A, B, 'same')`.

# Arguments
    A = Vector A
    B = Vector B (kernel)

# Returns
    centered convolution of A*B (length(A) elements)
"""
function conv_same(A, B)
    fullconv = conv(A, B)
    first_idx = cld(length(B) + 1, 2)
    return fullconv[first_idx:(first_idx + length(A) - 1)]
end

"""
Peel off motor unit spike train from EMG signals.

For each row, the MUAP is estimated as the mean window around the discharges. Its
convolution with the binary discharge train is then subtracted from that row.

# Arguments
    Z = whitened signal (residual)
    spikes = discharge indexes of the motor unit
    fsamp = sampling frequency
    r = half-width of the MUAP window, in seconds (converted to samples with fsamp)

# Returns
    Z = whitened signal (residual); a new array, the input is not modified. It is
        returned unchanged when no discharge lies far enough from the edges to cut a MUAP.
"""
function peeloff(Z, spikes, fsamp, r)
    r = round(Int, r * fsamp)
    # Binary discharge train: 1 at every discharge sample
    binst = zeros(size(Z, 2))
    binst[spikes] .= 1
    residual = copy(Z)
    for i in axes(Z, 1)
        allMUAPs = cutMUAP(spikes, r, Z[i, :])
        # Every row has the same discharges, so if none can be cut here, none can be
        # cut anywhere; averaging an empty set would fill the residual with NaN.
        size(allMUAPs, 1) == 0 && return residual
        MUAP = vec(mean(allMUAPs, dims=1))
        residual[i, :] .-= conv_same(binst, MUAP)
    end
    return residual
end

"""
Main function to run the EMG decomposition.

Each iteration proceeds as follows:
- Draw a random separation vector, deflated against the previous ones.
- Refine it with the fixed-point algorithm and detect its discharges.
- If it has more than `minspikes` discharges, refine the filter by minimising the
  CoV of the inter-spike intervals.
- Peel the unit off the residual when its silhouette exceeds `silthreshold`.

# Arguments
    Z = whitened signal (residual)
    fsamp = sampling frequency
    Niter = number of iterations

# Keywords
    maxiter = maximal number of fixed-point iterations (default 500)
    silthreshold = silhouette above which a unit is accepted and peeled off (default 0.9)
    peelradius = half-width of the peel-off MUAP window, in seconds (default 0.025)
    minspikes = a source needs more than this many discharges to be refined (default 10)

# Returns
    B = separation matrix (one column per iteration)
    Silmus = silhouette value of the identified motor units
    CoVmus = coefficient of variation of MU interspike intervals
"""
function emgdecomp(Z, fsamp, Niter; maxiter=500, silthreshold=0.9, peelradius=0.025, minspikes=10)

    # preallocate matrix
    B = zeros(Float64, size(Z, 1), Niter)
    Silmus = zeros(Float64, Niter)
    CoVmus = zeros(Float64, Niter)

    nMU = 0

    for i in 1:Niter
        w = randn(Float64, size(Z, 1), 1)
        # Initial Orthogonalisation and Normalisation
        w .-= B * (B' * w)
        w /= norm(w)

        # Fixed point algorithm
        w = fixedpointalg(w, Z, B, maxiter)
        _, spikes = getspikes(w, Z, fsamp)

        if length(spikes) > minspikes
            ISI = diff(spikes) / fsamp
            CoV = std(ISI) / mean(ISI)
            # Restart from the spike-triggered average of the whitened signal
            w = sum(Z[:, spikes], dims=2)
            w /= norm(w)
            w, spikes, CoVmus[i] = minimizeCOVISI(w, Z, CoV, fsamp)
            B[:, i] = w
            _, Silmus[i] = calcsil(w, Z, fsamp)

            if Silmus[i] > silthreshold
                Z = peeloff(Z, spikes, fsamp, peelradius)
                nMU += 1
                println("Iteration #", i, " - Motor unit #", nMU)
            end
        else
            # Keep the vector anyway: later iterations are deflated against B
            B[:, i] = w
        end
    end

    return B, Silmus, CoVmus
end

## 7. Run the EMG decomposition (15 iterations) and plot the output

In [None]:
# Run the EMG decomposition. `emgdecomp` allocates and returns the separation matrix,
# the silhouette values and the CoV of the ISI, so no preallocation is needed here.
Niter = 15
Ball, Silall, CoVall = emgdecomp(whitenedsignal, signal["sampling_frequency"], Niter)

In [None]:
using PlotlyJS
unit = 2  # column of Ball (decomposition iteration) to display
PTs, spikes = getspikes(Ball[:,unit], whitenedsignal, signal["sampling_frequency"])
# Time axis in seconds: sample k (1-based) is acquired at (k - 1) / fs
timeMU = (0:(size(whitenedsignal, 2) - 1)) ./ signal["sampling_frequency"]

plot_data = [
    PlotlyJS.scatter(x=timeMU, y=PTs, mode="lines", name="Pulse Train"),
    PlotlyJS.scatter(x=timeMU[spikes], y=PTs[spikes], mode="markers", name="Spikes")
]

layout = PlotlyJS.Layout(title="Motor unit", xaxis_title="Time (s)", yaxis_title="Amplitude")

PlotlyJS.plot(plot_data, layout)

In [None]:
# MUAP window radius: 25 ms expressed in samples
r = round(Int, 0.025*signal["sampling_frequency"])
# One averaged MUAP per channel of grid 3, at the discharges found in the previous cell.
# The number of rows comes from the data rather than being assumed to be 64.
MUAP = zeros(size(signal["emg_per_grid"]["Grid3"], 1), 2r + 1)
for i in axes(signal["emg_per_grid"]["Grid3"], 1)
    allMUAPs = cutMUAP(spikes, r, signal["emg_per_grid"]["Grid3"][i, :])
    MUAP[i, :] = vec(mean(allMUAPs, dims=1))
end

# x axis in milliseconds: lags of -r:r samples converted with the sampling frequency
plot_data = [
    PlotlyJS.scatter(
        x=(-r:r) ./ signal["sampling_frequency"] .* 1000,
        y=MUAP[i, :],
        mode="lines",
    ) for i in axes(MUAP, 1)
]

layout = PlotlyJS.Layout(
    title="Motor unit action potentials (MUAPs)",
    xaxis_title="Time (ms)",
    yaxis_title="Amplitude"
)
PlotlyJS.plot(plot_data, layout)

## 8. Adaptive filtering

In [None]:
using StatsBase

"""
To identify the discharge times of the motor unit, also returning the k-means centres.

Like `getspikes`, this returns the pulse train and the discharges. It additionally
returns the two k-means centres sorted in ascending order (noise centre first, spike
centre second), which seed the adaptive filtering.

# Arguments
    w = separation vector
    Z = whitened signal (residual)
    fsamp = sampling frequency

# Returns
    PTs = Motor unit pulse train
    spikes = discharge indexes of the motor unit
    sorted_centers = sorted centers of the kmeans clustering; [NaN, NaN] when fewer
                     than two peaks are found (no clustering is possible)
"""
function getspikes2(w, Z, fsamp)
    # Generate the motor unit pulse train (computed once)
    source = vec(w' * Z)
    PTs = source .* abs.(source)
    peaks, _ = findpeaks1d(PTs; distance=round(Int, fsamp * 0.005))     # Get the peaks
    spikes = peaks[:]
    # Defined on every path (previously undefined when fewer than two peaks were found)
    sorted_centers = fill(NaN, 2)
    if length(spikes) > 1
        kmeans_result = kmeans(PTs[spikes]', 2) # Separation of spikes and noise classes
        centers = vec(kmeans_result.centers)
        # Spikes should be in the class with the highest centroid
        idx = argmax(centers)
        spikes = spikes[kmeans_result.assignments .== idx]
        # Remove outliers (needs >= 2 spikes: the std of a single value is NaN)
        if length(spikes) > 1
            amplitudes = PTs[spikes]
            spikes = spikes[amplitudes .<= mean(amplitudes) + 3 * std(amplitudes)]
        end
        sorted_centers = sort(centers)
    end

    return PTs, spikes, sorted_centers
end

"""
Adapt the separation vector and the kmcentres.

Peaks of the pulse train are classified as follows:
- A peak closer to the spike centre than to the noise centre is a spike, unless it
  exceeds 3 × the spike centre (outlier).
- Other peaks count as noise.

When there are spikes, `w` moves towards their normalised spike-triggered average.
Each centre becomes a weighted average of its old value (weight `spikeweight`) and
the mean of its class in this window (weight = number of peaks in that class).

# Arguments
    w = separation vector
    Z = whitened signal (window)
    kmcentres = k-means centres, [noise centre, spike centre] (not modified)
    fsamp = sampling frequency
    w_learning_rate = learning rate for the separation vector
    spikeweight = weight of the previous centres in the update

# Returns
    w = updated separation vector
    kmcentres = updated centres of the kmeans clustering (a new array)
"""
function adaptseparation(w, Z, kmcentres, fsamp, w_learning_rate=0.05, spikeweight=5.0)
    # Work on a copy so the caller's centres are not changed behind its back
    kmcentres = copy(kmcentres)

    # Generate Pulse trains and get spikes
    source = vec(w' * Z)
    PTs = source .* abs.(source)
    peaks, _ = findpeaks1d(PTs; distance=round(Int, fsamp * 0.005))     # Get the peaks
    spikes = peaks[:]
    amplitudes = PTs[spikes]

    noisecentre = kmcentres[1]
    spikecentre = kmcentres[2]
    # 1 = spike, 0 = noise (or outlier: closer to the spike centre but > 3 × it)
    # NOTE(review): outliers are therefore averaged into the noise centre — confirm intent
    Distime = (abs.(amplitudes .- noisecentre) .> abs.(amplitudes .- spikecentre)) .-
              (amplitudes .> spikecentre * 3)
    isspike = Distime .== 1
    isnoise = Distime .== 0
    nspikes = count(isspike)
    nnoise = count(isnoise)

    if sum(Distime) > 0
        wnew = vec(sum(Z[:, spikes[isspike]], dims=2))
        wnew = wnew ./ norm(wnew)
        w = w .+ w_learning_rate .* wnew
        w = w ./ norm(w)

        # Update kmcentres. Each class mean is weighted by the number of peaks in that
        # class, so with no noise peak the noise centre is left unchanged.
        spikectnew = mean(amplitudes[isspike])
        noisectnew = nnoise > 0 ? mean(amplitudes[isnoise]) : 0.0
        kmcentres[2] = (spikeweight * kmcentres[2] + nspikes * spikectnew) / (spikeweight + nspikes)
        kmcentres[1] = (spikeweight * kmcentres[1] + nnoise * noisectnew) / (spikeweight + nnoise)
    end

    return w, kmcentres
end

In [None]:

w_learning_rate = 0.01
spikeweight = 5
# Adapt the MU filter over consecutive windows of 250 ms
window_size = Int(round(0.25 * signal["sampling_frequency"]))
n_windows = Int(ceil(size(whitenedsignal, 2) / window_size))
w_results = []
kmcentres_results = []

# Initial pulse train, discharges and k-means centres of the selected unit
PTs, spikes, kmcentres = getspikes2(Ball[:,unit], whitenedsignal, signal["sampling_frequency"])
w = Ball[:, unit]

for i in 1:n_windows
    # The loop updates the globals `w` and `kmcentres`. Declaring them global keeps
    # this cell working outside interactive soft scope (e.g. an included .jl file).
    # Otherwise the assignments would create new loop-local variables that are read
    # before being defined.
    global w, kmcentres

    # Calculate the start and end indices for the current window
    start_idx = (i - 1) * window_size + 1
    end_idx = min(i * window_size, size(whitenedsignal, 2))

    # View of the current window (no copy)
    whitenedsignal_window = view(whitenedsignal, :, start_idx:end_idx)

    # Apply the adaptseparation function to the current window
    w, kmcentres = adaptseparation(w, whitenedsignal_window, kmcentres, signal["sampling_frequency"], w_learning_rate, spikeweight)

    # Save the results
    push!(w_results, copy(w))
    push!(kmcentres_results, copy(kmcentres))

    # Print or log the results for this iteration
    println("Iteration $i:")
end

In [None]:
# Prepare kmcentres as lines. Each value is placed at the centre of the window it was
# estimated on; the last window can be shorter than `window_size`.
window_times = [
    ((i - 1) * window_size + min(i * window_size, size(whitenedsignal, 2))) / 2 /
    signal["sampling_frequency"]
    for i in 1:n_windows
]
centre1_vals = [centres[1] for centres in kmcentres_results]   # noise centre
centre2_vals = [centres[2] for centres in kmcentres_results]   # spike centre

# Original pulse train and spikes
plot_data = [
    PlotlyJS.scatter(x=timeMU, y=PTs, mode="lines", name="Pulse Train"),
    PlotlyJS.scatter(x=timeMU[spikes], y=PTs[spikes], mode="markers", name="Spikes"),

    # Add updated kmeans centres
    PlotlyJS.scatter(x=window_times, y=centre1_vals, mode="lines+markers", name="Centre Noise", line=attr(color="red")),
    PlotlyJS.scatter(x=window_times, y=centre2_vals, mode="lines+markers", name="Centre Spikes", line=attr(color="blue"))
]

layout = PlotlyJS.Layout(
    title="Motor Unit with Updated K-Means Centres",
    xaxis_title="Time (s)",
    yaxis_title="Amplitude"
)

PlotlyJS.plot(plot_data, layout)