In [2]:
import os
import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import mne
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

"""
  * @author: dengyufeng
  * @Created on 2024/8/14 11:31
 """

def _batch_output_path(file_path):
    """Return where the processed copy of *file_path* is written.

    The name is ``<stem>_processed.fif`` for every input format. It goes into
    the folder chosen with "选择保存路径"; when none has been chosen
    (``save_directory == ''``) it is written next to the source file.
    """
    directory = save_directory or os.path.dirname(file_path)
    stem = os.path.splitext(os.path.basename(file_path))[0]
    return os.path.join(directory, stem + '_processed.fif')


def _read_raw_by_extension(file_path):
    """Load *file_path* (preloaded) with the MNE reader matching its extension.

    The extension match is case-insensitive. Returns None for unsupported
    extensions.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if ext == '.edf':
        return mne.io.read_raw_edf(file_path, preload=True)
    if ext == '.bdf':
        return mne.io.read_raw_bdf(file_path, preload=True)
    return None


def _batch_params_error():
    """Return an error message if the GUI processing settings are unusable, else None."""
    try:
        for entry in (notch_freq_entry, l_freq_entry, h_freq_entry):
            float(entry.get())
    except ValueError as e:
        return f"输入参数错误: {e}"
    if ref_var.get() == '单通道参考' and not ref_channel_menu.get():
        return "请先选择参考通道。"
    return None


def batch_process():
    """Process several EEG files with the current GUI settings and save each one.

    Each selected .fif/.edf/.bdf file goes through these steps:

    1. It is loaded.
    2. It is run through process_data() (notch, band-pass and re-reference,
       without plotting).
    3. It is saved as FIF at the path given by _batch_output_path().

    Per-file successes and failures are reported in message boxes, followed by
    a final "done" dialog.

    The previously loaded recording (``process_data.raw``) and the global
    ``show_plot`` flag are restored afterwards. The batch therefore does not
    replace the data the user is working on interactively.

    NOTE(review): process_data() reports its own runtime errors (e.g. a failing
    filter) in a dialog and returns normally, so such a file is still saved.
    """
    global show_plot

    file_paths = filedialog.askopenfilenames(filetypes=[("All files", "*.*"), ("FIF files", "*.fif"), ("EDF files", "*.edf"), ("BDF files", "*.bdf")])

    if not file_paths:
        return

    # Validate once, before touching any file. Otherwise process_data() would
    # show the same error for every file, and each file would then be saved
    # unprocessed and reported as a success.
    error = _batch_params_error()
    if error:
        messagebox.showerror("错误", error)
        return

    ref_channel = ref_channel_menu.get() if ref_var.get() == '单通道参考' else None

    missing = object()
    previous_raw = getattr(process_data, "raw", missing)
    previous_show_plot = show_plot
    show_plot = False  # no plot / success dialog from process_data() per file
    try:
        for file_path in file_paths:
            try:
                raw_data = _read_raw_by_extension(file_path)
                if raw_data is None:
                    messagebox.showerror("错误", f"不支持的文件格式: {file_path}")
                    continue

                # The channel list was filled from the interactively loaded
                # file. Skip files lacking that channel instead of saving them
                # without the requested reference.
                if ref_channel is not None and ref_channel not in raw_data.ch_names:
                    messagebox.showerror("错误", f"处理数据文件失败: {file_path}\n参考通道不存在: {ref_channel}")
                    continue

                # process_data() operates on the shared process_data.raw slot.
                process_data.raw = raw_data
                process_data()

                save_file_path = _batch_output_path(file_path)
                raw_data.save(save_file_path, overwrite=True)
                messagebox.showinfo("信息", f"数据文件处理并保存成功: {save_file_path}")

            except Exception as e:
                messagebox.showerror("错误", f"处理数据文件失败: {file_path}\n{e}")
    finally:
        # Restore interactive state even if something above raised.
        show_plot = previous_show_plot
        if previous_raw is missing:
            if hasattr(process_data, "raw"):
                del process_data.raw
        else:
            process_data.raw = previous_raw

    messagebox.showinfo("信息", "批量处理完成。")

def process_data():
    """Notch-filter, band-pass-filter and re-reference ``process_data.raw`` in place.

    Parameters are read from the GUI:

    - notch frequency (Hz);
    - band-pass lower and upper edges (Hz);
    - reference type and, for a single-channel reference, the channel name.

    The recording is stored on the function attribute ``process_data.raw`` by
    load_data() / batch_process().

    When the module-level ``show_plot`` flag is true, the result is plotted and
    a success dialog is shown. All errors are reported through message boxes;
    nothing is raised to the caller.

    NOTE(review): processing is applied to the current data, so pressing the
    button again filters and re-references the same recording a second time.
    """
    if not hasattr(process_data, "raw"):
        messagebox.showerror("错误", "请先加载数据。")
        return

    try:
        # float (not int) so notch values such as "50.0" are accepted as well.
        notch_freq = float(notch_freq_entry.get())
        l_freq = float(l_freq_entry.get())
        h_freq = float(h_freq_entry.get())
        ref_type = ref_var.get()
        ref_channel = ref_channel_menu.get()

        # Check the reference choice before modifying the data. Otherwise a
        # missing channel would leave the recording filtered but not re-referenced.
        if ref_type == '单通道参考' and not ref_channel:
            messagebox.showerror("错误", "请先选择参考通道。")
            return

        raw = process_data.raw
        raw.notch_filter(freqs=notch_freq)
        raw.filter(l_freq=l_freq, h_freq=h_freq)

        if ref_type == '平均参考':
            # Added as a projector; MNE leaves it un-applied until apply_proj().
            raw.set_eeg_reference('average', projection=True)
        elif ref_type == '单通道参考':
            # MNE only supports projection=True for the average reference, so
            # a single-channel reference is applied directly to the data.
            raw.set_eeg_reference([ref_channel])

        if show_plot:
            plot_data()
            messagebox.showinfo("信息", "数据处理成功。")

        update_button_layout()

    except ValueError as e:
        messagebox.showerror("错误", f"输入参数错误: {e}")
    except Exception as e:
        messagebox.showerror("错误", f"数据处理失败: {e}")

def update_button_layout():
    """Place the process and save buttons side by side on grid row 7."""
    for column_index, button in enumerate((process_button, save_button)):
        button.grid(row=7, column=column_index, padx=10, pady=20)

def load_data():
    """Ask for one EEG file and load it with the reader for the selected format.

    The format combobox ("FIF", "EDF" or "BDF") chooses both the file-dialog
    filter and the MNE reader. On success the preloaded recording is stored
    on ``process_data.raw`` and the reference-channel list is refreshed.
    Failures are reported through message boxes.
    """
    selected_format = format_var.get()

    format_filters = {
        "FIF": ("FIF files", "*.fif"),
        "EDF": ("EDF files", "*.edf"),
        "BDF": ("BDF files", "*.bdf"),
    }
    any_file = ("All files", "*.*")
    if selected_format in format_filters:
        dialog_types = [format_filters[selected_format], any_file]
    else:
        dialog_types = [any_file]

    chosen_path = filedialog.askopenfilename(filetypes=dialog_types)

    if not chosen_path:
        return

    try:
        if selected_format not in format_filters:
            messagebox.showerror("错误", "请选择一个支持的文件格式。")
            return

        readers = {
            "FIF": mne.io.read_raw_fif,
            "EDF": mne.io.read_raw_edf,
            "BDF": mne.io.read_raw_bdf,
        }
        process_data.raw = readers[selected_format](chosen_path, preload=True)

        update_channel_list()
        messagebox.showinfo("信息", "数据加载成功。")
    except Exception as exc:
        messagebox.showerror("错误", f"数据加载失败: {exc}")

def update_channel_list():
    """Fill the reference-channel combobox with the loaded recording's channel names.

    Does nothing when no recording is loaded; otherwise also clears the
    current selection.
    """
    if not hasattr(process_data, "raw"):
        return
    ref_channel_menu['values'] = process_data.raw.info['ch_names']
    ref_channel_menu.set('')

def ref_type_changed(event=None):
    """Show the reference-channel label and picker only for single-channel referencing.

    Bound to ``<<ComboboxSelected>>`` on the reference-type combobox; *event*
    is unused.
    """
    wants_channel = ref_var.get() == '单通道参考'
    for column_index, widget in enumerate((ref_channel_label, ref_channel_menu)):
        if wants_channel:
            widget.grid(row=6, column=column_index, padx=10, pady=10)
        else:
            widget.grid_remove()

def plot_data():
    """Embed stacked traces of the first 1000 samples of every channel in the main window.

    The plot goes on its own grid row below the controls, replacing any
    previous plot there. Shows an error dialog if no recording is loaded.
    """
    # Dedicated row for the canvas. Row 9 holds the "选择保存路径" button and
    # the save-path label, which must not be destroyed or covered.
    plot_row = 10

    if not hasattr(process_data, "raw"):
        messagebox.showerror("错误", "请先加载数据。")
        return

    print("Plotting data...")

    data, times = process_data.raw[:, :1000]

    num_channels = data.shape[0]
    # Standalone Figure rather than pyplot: every redraw would otherwise leave
    # another figure registered in pyplot's global state. The minimum height
    # keeps a single-channel plot readable.
    fig = Figure(figsize=(8, max(2.0, 0.75 * num_channels)))
    # squeeze=False always yields a 2-D array, so one channel needs no special case.
    axs = fig.subplots(num_channels, 1, sharex=True, squeeze=False)[:, 0]

    for i, ax in enumerate(axs):
        ax.plot(times, data[i])
        ax.set_ylabel(f'Channel {i+1}')

    axs[-1].set_xlabel('Time (s)')
    fig.tight_layout(pad=1.0)

    # Remove the previous plot canvas (the only widget on this row).
    for widget in root.grid_slaves(row=plot_row):
        widget.destroy()

    canvas = FigureCanvasTkAgg(fig, master=root)
    canvas.draw()
    canvas.get_tk_widget().grid(row=plot_row, column=0, columnspan=3, pady=20)

def save_data():
    """Save ``process_data.raw`` to a FIF file chosen by the user.

    The dialog opens in the selected save directory. A ``.fif`` extension is
    appended unless the chosen name already ends with one. The extension check
    is case-insensitive, so "X.FIF" is not turned into "X.FIF.fif".
    Existing files are overwritten. Errors are reported through message boxes.
    """
    if not hasattr(process_data, "raw"):
        messagebox.showerror("错误", "请先加载和处理数据。")
        return

    file_path = filedialog.asksaveasfilename(initialdir=save_directory, defaultextension=".fif", filetypes=[("FIF files", "*.fif"), ("All files", "*.*")])

    if not file_path:
        return

    # Ensure exactly one .fif extension in the save path.
    save_file_path = file_path if file_path.lower().endswith('.fif') else file_path + '.fif'

    try:
        process_data.raw.save(save_file_path, overwrite=True)
        messagebox.showinfo("信息", f"数据保存成功: {save_file_path}")
    except Exception as e:
        messagebox.showerror("错误", f"数据保存失败: {e}")

def select_save_directory():
    """Ask for an output directory and store it in the global ``save_directory``.

    Cancelling the dialog (an empty result) keeps the previous choice. This
    keeps the stored path consistent with the label, which is only updated
    on a real selection.
    """
    global save_directory
    chosen = filedialog.askdirectory()
    if not chosen:
        return
    save_directory = chosen
    save_directory_label.config(text=f"保存路径: {save_directory}")

def close_window():
    """Destroy the main Tk window, which ends ``root.mainloop()``."""
    main_window = root
    main_window.destroy()

# --- Module-level state shared with the callbacks above ---------------------
# process_data() plots and shows its success dialog only while this is True;
# batch_process() turns it off during a batch run.
show_plot = True
# Output folder picked with "选择保存路径"; '' means none selected yet.
save_directory = ''

root = tk.Tk()
root.title("EEG 数据处理")

close_button = tk.Button(root, text="关闭", command=close_window)
close_button.grid(row=0, column=2, padx=10, pady=10, sticky=tk.E)

# Row 0: input format. readonly because load_data() only understands these
# exact values; typed text would only fail after the file dialog.
tk.Label(root, text="选择文件格式:").grid(row=0, column=0, padx=10, pady=10)
format_var = tk.StringVar(value='FIF')
format_menu = ttk.Combobox(root, textvariable=format_var, values=["FIF", "EDF", "BDF"], state="readonly")
format_menu.grid(row=0, column=1, padx=10, pady=10)

load_button = tk.Button(root, text="加载数据", command=load_data)
load_button.grid(row=1, column=0, columnspan=2, pady=20)

# Rows 2-4: filter parameters in Hz, parsed by process_data().
tk.Label(root, text="工频滤波频率 (Hz):").grid(row=2, column=0, padx=10, pady=10)
notch_freq_entry = tk.Entry(root)
notch_freq_entry.grid(row=2, column=1, padx=10, pady=10)

tk.Label(root, text="带通滤波下限 (Hz):").grid(row=3, column=0, padx=10, pady=10)
l_freq_entry = tk.Entry(root)
l_freq_entry.grid(row=3, column=1, padx=10, pady=10)

tk.Label(root, text="带通滤波上限 (Hz):").grid(row=4, column=0, padx=10, pady=10)
h_freq_entry = tk.Entry(root)
h_freq_entry.grid(row=4, column=1, padx=10, pady=10)

# Row 5: reference type. readonly so every change comes from the list and
# fires <<ComboboxSelected>>; a typed value would not show/hide the
# channel picker in ref_type_changed().
tk.Label(root, text="参考类型:").grid(row=5, column=0, padx=10, pady=10)
ref_var = tk.StringVar(value='平均参考')
ref_menu = ttk.Combobox(root, textvariable=ref_var, values=["平均参考", "单通道参考"], state="readonly")
ref_menu.grid(row=5, column=1, padx=10, pady=10)
ref_menu.bind("<<ComboboxSelected>>", ref_type_changed)

# Row 6: reference channel. These widgets are gridded/hidden by ref_type_changed().
ref_channel_label = tk.Label(root, text="参考通道:")
# Keep a Python reference to the variable. A tk.StringVar that gets
# garbage-collected unsets its Tcl variable, detaching it from the widget.
ref_channel_var = tk.StringVar()
ref_channel_menu = ttk.Combobox(root, textvariable=ref_channel_var, state="readonly")

# Row 7: process / save.
process_button = tk.Button(root, text="处理数据", command=process_data)
process_button.grid(row=7, column=0, padx=10, pady=20)

save_button = tk.Button(root, text="保存数据", command=save_data)
save_button.grid(row=7, column=1, padx=10, pady=20)

# Row 8: batch processing.
batch_button = tk.Button(root, text="批量处理", command=batch_process)
batch_button.grid(row=8, column=0, columnspan=2, pady=20)

# Row 9: output directory selection and display.
select_directory_button = tk.Button(root, text="选择保存路径", command=select_save_directory)
select_directory_button.grid(row=9, column=0, padx=10, pady=10)

save_directory_label = tk.Label(root, text="保存路径: 未选择")
save_directory_label.grid(row=9, column=1, padx=10, pady=10)

root.mainloop()


Opening raw data file C:/Users/pc/Desktop/李佳/S002/Subject_001.fif...
Isotrak not found
    Range : 0 ... 59999 =      0.000 ...   374.994 secs
Ready.
Reading 0 ... 59999  =      0.000 ...   374.994 secs...
Filtering raw data in 3 contiguous segments
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1057 samples (6.606 s)

Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pas

  raw_data = mne.io.read_raw_fif(file_path, preload=True)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1057 samples (6.606 s)

Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 3 - 18 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 3.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.00 Hz)
- Upper passband edge: 18.00 Hz
- Upper transition bandwidth: 4

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Writing C:\Users\pc\Desktop\李佳\S002\Subject_001_processed.fif
Closing C:\Users\pc\Desktop\李佳\S002\Subject_001_processed.fif
[done]


  process_data.raw.save(save_file_path, overwrite=True)


Opening raw data file C:/Users/pc/Desktop/李佳/S002/Subject_002.fif...
Isotrak not found
    Range : 0 ... 59039 =      0.000 ...   368.994 secs
Ready.
Reading 0 ... 59039  =      0.000 ...   368.994 secs...
Filtering raw data in 3 contiguous segments
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1057 samples (6.606 s)

Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 pas

  raw_data = mne.io.read_raw_fif(file_path, preload=True)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 3 - 18 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 3.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.00 Hz)
- Upper passband edge: 18.00 Hz
- Upper transition bandwidth: 4.50 Hz (-6 dB cutoff frequency: 20.25 Hz)
- Filter length: 265 samples (1.656 s)

EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Writing C:\Users\pc\Desktop\李佳\S002\Subject_002_processed.fif


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
  process_data.raw.save(save_file_path, overwrite=True)


Closing C:\Users\pc\Desktop\李佳\S002\Subject_002_processed.fif
[done]
