In [1]:
import numpy as np
import pandas as pd
import mne
import asrpy
import pathlib

from scipy.stats import pearsonr
from scipy.stats import wilcoxon

from mne.preprocessing import ICA

In [2]:
# Define the DataFrame with appropriate columns
columns = [
    "Segment_ID",
    "RMS_Noisy",
    "RMS_Cleaned",
    "RMS_diff",
    "Correlation",
    "P_Value",
]
results_df = pd.DataFrame(columns=columns)

In [3]:
data = {}
for filename in pathlib.Path("./lee2019-artifacts/").glob("*.fif"):
    data[filename] = mne.io.read_raw(filename, preload=True, verbose=False).pick(
        ["O1", "O2", "Fp1", "Fp2"]
    )

segment_length = 30
all_segments = []

for filename, raw in data.items():
    n_segments = int(np.floor(raw.times[-1] / segment_length))

    for i in range(n_segments):
        start_sec = i * segment_length
        stop_sec = start_sec + segment_length
        stop_sec = min(stop_sec, raw.times[-1])

        cropped_raw = raw.copy().crop(
            tmin=start_sec, tmax=stop_sec - 1 / raw.info["sfreq"]
        )
        all_segments.append(cropped_raw)

In [4]:
len(all_segments)

231

In [5]:
def rms(raw):
    """
    Computes the Root Mean Square of the signal, providing a measure of the overall power in the signal.
    """
    data = raw.get_data()
    rms = np.sqrt(np.mean(data**2))
    return np.round(rms, 4)

In [6]:
def correlation(noisy, clean):
    """
    Calculates the Pearson correlation coefficient for each channel between the
    noisy and cleaned signals to assess how well the overall structure of the signal
    is preserved post-cleaning.
    """
    correlation_coeffs = []
    for ch in range(4):
        coeff, _ = pearsonr(noisy.get_data(picks=ch)[0], clean.get_data(picks=ch)[0])
        correlation_coeffs.append(coeff)

    average_corr = np.mean(correlation_coeffs)

    return np.round(average_corr, 4)

In [7]:
def significance(noisy, clean):
    """
    Uses the Wilcoxon signed-rank test to determine if there is a
    statistically significant difference between the noisy and cleaned EEG signals.
    """
    data_original = noisy.get_data().flatten()
    data_cleaned = clean.get_data().flatten()

    _, p = wilcoxon(data_original, data_cleaned)

    return p

In [8]:
def asr_pipeline(raw):
    raw.filter(1, 40, fir_design="firwin")
    raw.resample(256)

    processed_raw = raw.copy()

    try:
        asr = asrpy.ASR(sfreq=256, cutoff=15)
        asr.fit(processed_raw)
        reconstructed = asr.transform(processed_raw)
    except:
        reconstructed = processed_raw

    return reconstructed

In [9]:
# for i, key in enumerate(all_segments):

#     cleaned = asr_pipeline(data[key])

#     rms_noisy = rms(data[key])
#     rms_cleaned = rms(cleaned)
#     corr = correlation(data[key], cleaned)
#     try:
#         p_value = significance(data[key], cleaned)
#     except:
#         p_value = -1

#     results_df = results_df.append(
#         {
#             "Segment_ID": i,
#             "RMS_Noisy": rms_noisy,
#             "RMS_Cleaned": rms_cleaned,
#             "RMS_diff": np.abs(rms_cleaned - rms_noisy),
#             "Correlation": corr,
#             "P_Value": p_value,
#         },
#         ignore_index=True,
#     )

In [10]:
for i, noisy_segment in enumerate(all_segments):
    cleaned = asr_pipeline(noisy_segment)

    rms_noisy = rms(noisy_segment)
    rms_cleaned = rms(cleaned)
    corr = correlation(noisy_segment, cleaned)
    try:
        p_value = significance(noisy_segment, cleaned)
    except:
        p_value = None

    results_df = results_df.append(
        {
            "Segment_ID": i,
            "RMS_Noisy": rms_noisy,
            "RMS_Cleaned": rms_cleaned,
            "RMS_diff": np.abs(rms_cleaned - rms_noisy),
            "Correlation": corr,
            "P_Value": p_value,
        },
        ignore_index=True,
    )

  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df =

Try calibrating ASR with cleaner data.


  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df = results_df.append(
  results_df =

In [11]:
results_df.to_csv("asr_analysis_results.csv", index=False)

In [12]:
# Basic descriptive statistics
print(results_df.describe())

# Filter segments with significant improvements
significant_improvements = results_df[results_df["P_Value"] < 0.05]
print("Significant Improvements Count:", significant_improvements.shape[0])

       Segment_ID   RMS_Noisy  RMS_Cleaned    RMS_diff  Correlation  \
count  231.000000  231.000000   231.000000  231.000000   231.000000   
mean   115.000000   72.911465    32.784584   40.126881     0.527460   
std     66.828138   37.172106    32.584414   34.289274     0.265043   
min      0.000000    8.191500     4.445000    0.000000     0.074600   
25%     57.500000   46.936300     9.972850   16.285150     0.319600   
50%    115.000000   66.042600    17.894300   33.354300     0.462300   
75%    172.500000   96.429500    45.632400   56.107300     0.714250   
max    230.000000  242.468800   185.752900  226.353900     1.000000   

             P_Value  
count   2.130000e+02  
mean    4.325894e-02  
std     1.518949e-01  
min     0.000000e+00  
25%    1.249474e-169  
50%     7.941557e-34  
75%     1.981576e-06  
max     9.954392e-01  
Significant Improvements Count: 186
