# In[ ]:
import numpy as np
from scipy.signal import decimate, detrend, filtfilt
from scipy.signal import firwin, iirnotch, butter

def _design_filter(filt_name, filt_params, sampleRate):
    """Return ``(b, a)`` coefficients for one named filter spec, or None.

    filt_name: "notch" or "lowpass"; any other name is unsupported (None).
    filt_params: list whose layout depends on filt_name (see examples below).
    sampleRate: sampling rate in Hz used to interpret the given frequencies.
    """
    if filt_name == "notch":
        # Example: ['bandstopfir', 1000, 59, 61] -> IIR notch centred between
        # the two band edges (here 60 Hz) with Q = f0 / bandwidth.
        # Element 1 is not used by this design.
        low, high = filt_params[2], filt_params[3]
        f0 = (low + high) / 2
        Q = f0 / (high - low)
        return iirnotch(f0, Q, fs=sampleRate)
    if filt_name == "lowpass":
        # Example: ['lowpassfir', 2, 120] -> FIR of the given order
        # (order + 1 taps) with a 120 Hz cutoff.
        order = filt_params[1]
        cutoff = filt_params[2]
        return firwin(order + 1, cutoff, fs=sampleRate), 1
    # Add more filter types here as needed.
    return None


def CFilter(signal, time, parms):
    """
    Generic downsampling and filtering function.

    Steps are applied in the iteration order of ``parms``; unknown keys are
    ignored. The input arrays are never modified.

    signal: array-like, (samples x channels) or 1-D (samples,).
    time: array-like (samples,), sample timestamps; the sampling rate is
        estimated as 1 / median(diff(time)) and reported in Hz.
    parms: dict with optional keys:
        'decimate': ['frequency', target_hz] to use the integer factor
            round(fs / target_hz), or [R] for an explicit integer factor R.
        'filtfilt': dict mapping filter name to its spec, e.g.
            {'notch': ['bandstopfir', 1000, 59, 61],
             'lowpass': ['lowpassfir', 2, 120]}. Unsupported names are
            skipped with a message.
        'detrend': any value; removes a linear trend along the sample axis.

    Returns (signal, time): the processed signal and the timestamps of its
    samples (for decimation these are exactly the original timestamps of the
    retained samples).

    Raises ValueError if the decimation factor works out below 1.
    """
    signal = np.asarray(signal)
    time = np.asarray(time)
    sampleRate = 1.0 / np.median(np.diff(time))

    for key in parms:
        if key == "decimate":
            spec = parms["decimate"]
            if len(spec) == 2 and spec[0] == "frequency":
                R = int(round(sampleRate / spec[1]))
            else:
                R = spec[0]
            if R < 1:
                raise ValueError(
                    f"Decimation factor must be >= 1, got {R} "
                    f"(sample rate {sampleRate:.0f} Hz)."
                )
            # Report the rate actually achieved by the integer factor.
            targetRate = sampleRate / R
            print(f"Downsampling from {sampleRate:.0f} Hz to {targetRate:.0f} Hz (decimate)...")
            if R > 1:
                # decimate() keeps every R-th filtered sample starting at
                # index 0, so the matching timestamps are exactly time[::R].
                signal = decimate(signal, R, ftype='fir', axis=0, zero_phase=True)
                time = time[::R]
                sampleRate = 1.0 / np.median(np.diff(time))
            print("Done.")
        elif key == "filtfilt":
            for filt_name, filt_params in parms["filtfilt"].items():
                coeffs = _design_filter(filt_name, filt_params, sampleRate)
                if coeffs is None:
                    print(f"Skipping unsupported filter ({filt_name}).")
                    continue
                print(f"Applying filter ({filt_name})...")
                b, a = coeffs
                # filtfilt returns a new array, so the caller's input is
                # never overwritten (and integer input is not truncated).
                signal = filtfilt(b, a, signal, axis=0)
                print("Done.")
        elif key == "detrend":
            print("Detrending...")
            signal = detrend(signal, axis=0, type='linear')
            print("Done.")
    return signal, time