In [None]:
%matplotlib inline

import os

import mne
import numpy as np

from _helper_functions import plot_psd

## Spectral analysis of data - the `time_frequency` module

Often in neuroscience, we are interested in determining the spectral composition of signals, represented as a power spectral density (PSD).

Furthermore, we may want to isolate the spectral content of a signal in a particular range of frequencies, which involves filtering the data.

### Part 1 - Filtering data

To examine activity at a limited range of frequencies, we perform spectral filtering.

This can take the form of:
- Lowpass filtering - retaining information content of a signal below a certain frequency.
- Highpass filtering - retaining information content of a signal above a certain frequency.
- Bandpass filtering - retaining information content of a signal within a certain frequency range.
- Bandstop filtering - retaining information content of a signal outside a certain frequency range.

<br>
<img src="figures/filter_types.png" alt="Filter types" width="30%" height="30%">

Credit: [allaboutcircuits.com](https://www.allaboutcircuits.com/technical-articles/low-pass-filter-tutorial-basics-passive-RC-filter/)

To see how to perform such filtering in MNE, we start by simulating 10 seconds of data sampled at 200 Hz, consisting of three channels containing sine waves at 5 Hz, 10 Hz, and 20 Hz respectively.

In [None]:
# Simulation settings
duration = 10  # seconds
sfreq = 200  # sampling rate (Hz)

# Timepoints of the simulated data
times = np.linspace(start=0, stop=duration, num=sfreq * duration, endpoint=False)

# Simulate data as sine waves of given frequencies
chan_1 = np.sin(2 * np.pi * times * 5)  # 5 Hz signal
chan_2 = np.sin(2 * np.pi * times * 10)  # 10 Hz signal
chan_3 = np.sin(2 * np.pi * times * 20)  # 20 Hz signal

# Combine channels into a single array
data = np.array([chan_1, chan_2, chan_3])
ch_names = ["5Hz", "10Hz", "20Hz"]  # channel names

To play around with the data, let us first store it in a `Raw` object.

**Exercises - Spectral filtering**

**Exercise:** Create an [`Info`](https://mne.tools/stable/generated/mne.Info.html) object for the 3 channels, setting the channel types to be EEG, and using the sampling frequency we specified above.

Afterwards, use the `data` array and the `Info` object to create a [`RawArray`](https://mne.tools/stable/generated/mne.io.RawArray.html) object for the signals, called `raw`.

*Hint:* use the [`create_info()`](https://mne.tools/stable/generated/mne.create_info.html) function to create the `Info` object.

In [None]:
## CODE GOES HERE
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
raw = mne.io.RawArray(data=data, info=info)

We can now easily plot the data.

Count the number of cycles in each channel per second. Do they match our expectations?

In [None]:
# The object containing the signals should be called `raw`
raw.plot(duration=6, scalings="auto");

#### Computing power spectra - a brief introduction

We can easily compute the power spectral density of channels in `Raw` objects.

This is done by calling the [`compute_psd()`](https://mne.tools/stable/generated/mne.io.Raw.html#mne.io.Raw.compute_psd) method.

`compute_psd()` returns a [`mne.time_frequency.Spectrum`](https://mne.tools/stable/generated/mne.time_frequency.Spectrum.html#mne.time_frequency.Spectrum) object containing the power spectra.

In [None]:
# Compute PSD of the data
spectrum = raw.compute_psd(fmax=30)
spectrum

We can plot the PSD using the [`plot()`](https://mne.tools/stable/generated/mne.time_frequency.Spectrum.html#mne.time_frequency.Spectrum.plot) method of the `Spectrum` object.

In [None]:
# Plot the PSD
spectrum.plot();

As you can see, there are distinct peaks in the power spectrum at 5, 10, and 20 Hz.
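
The peak locations can also be checked numerically. The following is a minimal sketch that recreates the simulated signals and uses a plain FFT periodogram from NumPy. This is not the Welch estimate that `compute_psd()` uses by default, and the names `fft_freqs`, `power`, and `peak_freqs` are introduced here purely for illustration.

```python
import numpy as np

# Same simulation settings as above
duration = 10  # seconds
sfreq = 200  # sampling rate (Hz)
times = np.linspace(start=0, stop=duration, num=sfreq * duration, endpoint=False)
data = np.array([np.sin(2 * np.pi * times * f) for f in (5, 10, 20)])

# Plain periodogram: squared magnitude of the FFT of each channel
fft_freqs = np.fft.rfftfreq(data.shape[1], d=1 / sfreq)
power = np.abs(np.fft.rfft(data, axis=1)) ** 2

# Frequency carrying the most power in each channel
peak_freqs = fft_freqs[np.argmax(power, axis=1)]
print(peak_freqs)
```

Because each sine wave completes a whole number of cycles within the 10 seconds, its power falls into a single frequency bin and the peak is unambiguous.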

We will explore PSD computation in more detail below after examining spectral filtering.

#### Spectral filtering

Let us now look at how we can filter the data.

##### Lowpass, highpass, bandpass, and bandstop filtering

Lowpass, highpass, bandpass, and bandstop filtering is most easily done using the [`filter()`](https://mne.tools/stable/generated/mne.io.Raw.html#mne.io.Raw.filter) method of `Raw` objects.

Frequencies to filter using the `filter()` method are specified using the `l_freq` and `h_freq` parameters:
- `l_freq` specifies the lowest frequency of information to retain (in Hz).
- `h_freq` specifies the highest frequency of information to retain (in Hz).

<br>

In this way:
- specifying only `l_freq` highpass filters the data.
- specifying only `h_freq` lowpass filters the data.
- specifying `l_freq` to be lower than `h_freq` bandpass filters the data.
- specifying `l_freq` to be higher than `h_freq` bandstop filters the data.

<br>
<img src="figures/filter_types_marked.png" alt="Filter types with frequency parameters" width="30%" height="30%">

Adapted from: [allaboutcircuits.com](https://www.allaboutcircuits.com/technical-articles/low-pass-filter-tutorial-basics-passive-RC-filter/)
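
The `l_freq` / `h_freq` convention can be summarised in a few lines of plain Python. The sketch below is only an idealised illustration: `filter_type()` and `ideal_retained()` are hypothetical helpers (not part of MNE), and they assume a perfect "brick-wall" filter with no transition band, which real filters do not have.

```python
import numpy as np


def filter_type(l_freq=None, h_freq=None):
    """Name the filter implied by the l_freq / h_freq convention."""
    if l_freq is None and h_freq is None:
        return "none"
    if h_freq is None:
        return "highpass"
    if l_freq is None:
        return "lowpass"
    return "bandpass" if l_freq < h_freq else "bandstop"


def ideal_retained(freqs, l_freq=None, h_freq=None):
    """Boolean mask of the frequencies an idealised filter would keep."""
    freqs = np.asarray(freqs, dtype=float)
    above_l = np.ones(freqs.shape, dtype=bool) if l_freq is None else freqs >= l_freq
    below_h = np.ones(freqs.shape, dtype=bool) if h_freq is None else freqs <= h_freq
    if filter_type(l_freq, h_freq) == "bandstop":
        return above_l | below_h  # keep everything outside the stop band
    return above_l & below_h  # keep everything inside the pass band


# Which of the simulated frequencies would survive each filter?
signal_freqs = [5, 10, 20]
for l_freq, h_freq in [(8, None), (None, 15), (8, 15), (15, 8)]:
    kept = ideal_retained(signal_freqs, l_freq, h_freq)
    print(filter_type(l_freq, h_freq), [f for f, k in zip(signal_freqs, kept) if k])
```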

The example below shows how to highpass filter the data to remove the 5 Hz activity.

**N.B.** The `Raw` object is copied first because `filter()` modifies the data in place; working on a copy leaves the original data untouched.

In [None]:
# Copy to preserve original data
raw_copy = raw.copy()

# Highpass filter at 8 Hz to exclude 5 Hz activity
raw_copy.filter(l_freq=8, h_freq=None)

# Compute the PSD of the new data and plot it
raw_copy.compute_psd(fmax=30).plot()

# View the filtered data
raw_copy.plot(duration=6, scalings="auto");

As you can see, the activity in the 5 Hz channel has been removed, along with the peak in the power spectrum at 5 Hz.

On the other hand, the 10 and 20 Hz activity remains.

**Exercises - Lowpass, highpass, bandpass, and bandstop filtering**

**Exercise:** Lowpass filter the data at 15 Hz to remove the 20 Hz activity, then plot the PSD and raw data to confirm that only the 20 Hz activity has been removed.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.filter(l_freq=None, h_freq=15)
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

**Exercise:** Lowpass filter the data at 8 Hz to remove the 10 and 20 Hz activity, then plot the PSD and raw data to confirm that only the 5 Hz activity remains.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.filter(l_freq=None, h_freq=8)
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

**Exercise:** Bandstop filter the data between 8 and 15 Hz to remove the 10 Hz activity, then plot the PSD and raw data to confirm that only the 10 Hz activity has been removed.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.filter(l_freq=15, h_freq=8)
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

**Exercise:** Bandpass filter the data between 8 and 15 Hz to remove the 5 and 20 Hz activity, then plot the PSD and raw data to confirm that only the 10 Hz activity remains.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.filter(l_freq=8, h_freq=15)
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

As you can see, the filters from the `filter()` method can operate over a large frequency range.

Sometimes, however, we want to attenuate only a very narrow frequency range, for example when removing line noise artefacts (see e.g.):
- https://pressrelease.brainproducts.com/eeg-artifacts-handling-in-analyzer/#technical
- https://labeling.ucsd.edu/tutorial/labels
- https://mne.tools/stable/auto_tutorials/preprocessing/30_filtering_resampling.html#power-line-noise

In these situations, a notch filter is often used.

##### Notch filtering

Notch filters have their own dedicated [`notch_filter()`](https://mne.tools/stable/generated/mne.io.Raw.html#mne.io.Raw.notch_filter) method in the `Raw` object.

Below, we use a notch filter to remove the 5 Hz activity alone.

In [None]:
raw_copy = raw.copy()

# Apply notch filter at 5 Hz
raw_copy.notch_filter(freqs=5)

# Plot the PSD and timeseries of the filtered data
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");
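
To illustrate how narrow a notch can be, the sketch below applies a notch at 5 Hz to a single signal containing all three sine waves mixed together. It uses SciPy's second-order IIR notch (`scipy.signal.iirnotch()`), which is a different filter design from the one `notch_filter()` uses by default, so treat it as a conceptual illustration only. The quality factor `Q=5` and the `amplitude_at()` helper are assumptions made for this example.

```python
import numpy as np
from scipy import signal

sfreq = 200  # sampling rate (Hz)
times = np.linspace(start=0, stop=10, num=10 * sfreq, endpoint=False)

# One signal containing all three frequencies
mixture = sum(np.sin(2 * np.pi * times * f) for f in (5, 10, 20))

# IIR notch centred on 5 Hz; bandwidth = w0 / Q, so Q=5 gives a notch ~1 Hz wide
b, a = signal.iirnotch(w0=5, Q=5, fs=sfreq)
notched = signal.filtfilt(b, a, mixture)  # forward-backward pass (zero phase)


def amplitude_at(x, freq):
    """Amplitude of the FFT bin closest to freq (1.0 for a unit sine wave)."""
    spectrum = np.abs(np.fft.rfft(x)) / (len(x) / 2)
    fft_freqs = np.fft.rfftfreq(len(x), d=1 / sfreq)
    return spectrum[np.argmin(np.abs(fft_freqs - freq))]


for freq in (5, 10, 20):
    before = amplitude_at(mixture, freq)
    after = amplitude_at(notched, freq)
    print(f"{freq} Hz: amplitude before = {before:.2f}, after = {after:.2f}")
```

Only the 5 Hz component should be strongly attenuated; the 10 Hz component, just 5 Hz away, should pass almost unchanged.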

**Exercises - Notch filtering**

**Exercise:** Notch filter the data at 10 Hz, and visualise the PSD and raw data to confirm that only the 10 Hz activity has been removed.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.notch_filter(freqs=10)
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

**Exercise:** Notch filter the data at 20 Hz, and visualise the PSD and raw data to confirm that only the 20 Hz activity has been removed.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.notch_filter(freqs=20)
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

**Exercise:** Notch filter the data at 10 and 20 Hz in a single call to the `notch_filter()` method, and visualise the PSD and raw data to confirm that only the 5 Hz activity remains.

In [None]:
## CODE GOES HERE
raw_copy = raw.copy()
raw_copy.notch_filter(freqs=[10, 20])
raw_copy.compute_psd(fmax=30).plot()
raw_copy.plot(duration=6, scalings="auto");

As you can see, MNE provides a number of convenient tools for the spectral filtering of data.

There are many options for specifying filter parameters to fine-tune the filters for your needs, which are discussed in more depth in the following tutorials:
- Background on filtering: https://mne.tools/stable/auto_tutorials/preprocessing/25_background_filtering.html
- Filtering and resampling: https://mne.tools/stable/auto_tutorials/preprocessing/30_filtering_resampling.html

### Part 2 - Computing power spectral densities

Up until now, we have computed PSDs using the `compute_psd()` method of `Raw` objects.

Note that an equivalent method exists for `Epochs` objects: [`mne.Epochs.compute_psd()`](https://mne.tools/stable/generated/mne.Epochs.html#mne.Epochs.compute_psd).

The `compute_psd()` methods of `Raw` and `Epochs` objects support PSD computations using the Welch and multitaper methods.

There exist equivalent functions for computing PSDs using the Welch and multitaper methods from arrays of data:
- [`mne.time_frequency.psd_array_welch()`](https://mne.tools/stable/generated/mne.time_frequency.psd_array_welch.html)
- [`mne.time_frequency.psd_array_multitaper()`](https://mne.tools/stable/generated/mne.time_frequency.psd_array_multitaper.html)

**N.B.** Performing the PSD computations on arrays requires the sampling frequency of the data (`sfreq` parameter) to be specified.
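
As background on the Welch method named above, the core idea is to split the signal into segments, compute a windowed periodogram for each, and average the results. The sketch below is a simplified NumPy illustration of that idea, not MNE's implementation: `welch_psd()` is a hypothetical helper, the Hann window is an arbitrary choice, and an even `n_fft` is assumed.

```python
import numpy as np


def welch_psd(x, sfreq, n_fft=256, n_overlap=0):
    """One-sided PSD from averaged windowed periodograms (even n_fft assumed)."""
    window = np.hanning(n_fft)
    scale = sfreq * np.sum(window**2)  # converts |FFT|^2 to power per Hz
    step = n_fft - n_overlap
    periodograms = []
    for start in range(0, len(x) - n_fft + 1, step):
        segment = x[start : start + n_fft] * window
        power = np.abs(np.fft.rfft(segment)) ** 2 / scale
        power[1:-1] *= 2  # fold negative frequencies into the one-sided spectrum
        periodograms.append(power)
    freqs = np.fft.rfftfreq(n_fft, d=1 / sfreq)
    return freqs, np.mean(periodograms, axis=0)


# Example: a 10 Hz sine wave sampled at 200 Hz for 10 seconds
sfreq = 200
times = np.linspace(start=0, stop=10, num=10 * sfreq, endpoint=False)
signal_10hz = np.sin(2 * np.pi * times * 10)

freqs, psd = welch_psd(signal_10hz, sfreq=sfreq, n_fft=200)
print(f"Peak at {freqs[np.argmax(psd)]:.1f} Hz")
print(f"Total power: {np.sum(psd) * (freqs[1] - freqs[0]):.2f}")
```

Note that the sampling frequency enters both the frequency axis and the scaling of the power values, which is why the array-based functions need it as an explicit argument.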

Here, using the sample data, we specify the Welch method in `compute_psd()`, with an FFT length of 2,048 samples.

Using the `fmax` parameter, we only return the results up to 50 Hz.

We additionally take only the EEG channels and crop to the first 60 seconds to reduce computation time.

In [None]:
# Load the sample data
raw = mne.io.read_raw_fif(
    os.path.join(mne.datasets.sample.data_path(), "MEG", "sample", "sample_audvis_raw.fif")
)
raw.pick(picks="eeg", exclude="bads")
raw.crop(tmax=60)
raw.load_data()

# Compute PSD
spectrum = raw.compute_psd(method="welch", fmax=50, n_fft=2048)

# Plot the PSD
spectrum.plot();

We can extract the power values and the corresponding frequencies from the `Spectrum` object.

In [None]:
# Extract PSD data
psd = spectrum.get_data()

# Extract frequencies in the PSD
freqs = spectrum.freqs

print(f"PSD data has shape: {psd.shape} (channels x frequencies)")
print(f"Frequencies has shape: {freqs.shape} (frequencies)")

Using a custom function `plot_psd()`, we can verify that these values match those plotted using the `plot()` method of the `Spectrum` object.

For the `plot_psd()` function, we pass in the array of power spectral density values for a set of channels, alongside the corresponding frequencies.

In [None]:
# Plot PSD from arrays with custom function
plot_psd(psd=psd, freqs=freqs)

#### Computing PSDs from standalone functions

**Exercises - Computing PSDs**

**Exercise:** Perform the equivalent computation using the `psd_array_welch()` function on the data array extracted from the `Raw` object.

Remember to specify a maximum frequency of 50 Hz and an FFT length of 2,048.

Use the custom `plot_psd()` function to visualise the results.

Do the results match the output of `compute_psd()`?

*Hint:* data can be extracted from `Raw` objects using the [`get_data()`](https://mne.tools/stable/generated/mne.io.Raw.html#mne.io.Raw.get_data) method.

In [None]:
## CODE GOES HERE
psd, freqs = mne.time_frequency.psd_array_welch(
    x=raw.get_data(), sfreq=raw.info["sfreq"], fmax=50, n_fft=2048
)
plot_psd(psd=psd, freqs=freqs)

**Exercise:** Again using `psd_array_welch()`, compute the PSDs for the frequency range from 5 Hz onwards (i.e. no 50 Hz limit), and visualise the results.

*Hint:* use the `fmin` parameter to specify the starting frequency.

In [None]:
## CODE GOES HERE
psd, freqs = mne.time_frequency.psd_array_welch(
    x=raw.get_data(), sfreq=raw.info["sfreq"], fmin=5, n_fft=2048
)
plot_psd(psd=psd, freqs=freqs)

**Exercise:** Using `psd_array_welch()`, compute the PSDs for the frequency range 5 - 50 Hz, and visualise the results.

In [None]:
## CODE GOES HERE
psd, freqs = mne.time_frequency.psd_array_welch(
    x=raw.get_data(), sfreq=raw.info["sfreq"], fmin=5, fmax=50, n_fft=2048
)
plot_psd(psd=psd, freqs=freqs)

##### Indexing frequencies of results

**Exercise:** Using `psd_array_welch()`, compute the PSDs for the entire frequency range, but only plot the results up to 50 Hz.

*Hint:* Use the [`np.where()`](https://numpy.org/doc/stable/reference/generated/numpy.where.html) function to find where the appropriate frequency values are.

In [None]:
## CODE GOES HERE
psd, freqs = mne.time_frequency.psd_array_welch(
    x=raw.get_data(), sfreq=raw.info["sfreq"], n_fft=2048
)
plot_freqs_idx = np.where(freqs <= 50)[0]
plot_psd(psd=psd[:, plot_freqs_idx], freqs=freqs[plot_freqs_idx])

**Exercise:** Using `psd_array_welch()`, compute the PSDs for the entire frequency range, but only plot the results from 5 Hz onwards.

In [None]:
## CODE GOES HERE
psd, freqs = mne.time_frequency.psd_array_welch(
    x=raw.get_data(), sfreq=raw.info["sfreq"], n_fft=2048
)
plot_freqs_idx = np.where(freqs >= 5)[0]
plot_psd(psd=psd[:, plot_freqs_idx], freqs=freqs[plot_freqs_idx])

**Exercise:** Using `psd_array_welch()`, compute the PSDs for the entire frequency range, but only plot the results from 5 - 50 Hz.

*Hint:* Use the form `np.where((condition1) & (condition2))` when you want to index an array based on multiple conditions.

In [None]:
## CODE GOES HERE
psd, freqs = mne.time_frequency.psd_array_welch(
    x=raw.get_data(), sfreq=raw.info["sfreq"], n_fft=2048
)
plot_freqs_idx = np.where((freqs >= 5) & (freqs <= 50))[0]
plot_psd(psd=psd[:, plot_freqs_idx], freqs=freqs[plot_freqs_idx])
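
The multi-condition indexing used in these exercises can be tried out on a small synthetic array, where the result is easy to check by hand. The sketch below uses made-up values (`toy_freqs` and `toy_psd` are illustrative names, not data from the recording), and also shows that a boolean mask gives the same selection as `np.where()`.

```python
import numpy as np

# Toy frequency axis from 0 to 100 Hz in 5 Hz steps, and a fake 2-channel "PSD"
toy_freqs = np.arange(0, 101, 5.0)
toy_psd = np.vstack([toy_freqs * 1.0, toy_freqs * 2.0])  # channels x frequencies

# Indices of the frequencies between 5 and 50 Hz (inclusive)
idx = np.where((toy_freqs >= 5) & (toy_freqs <= 50))[0]
print(toy_freqs[idx])

# A boolean mask selects the same entries without calling np.where()
mask = (toy_freqs >= 5) & (toy_freqs <= 50)
print(np.array_equal(toy_psd[:, idx], toy_psd[:, mask]))
```

The parentheses around each condition are required, since `&` binds more tightly than the comparison operators.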

##### Summary of PSD computation

As you can see, the `compute_psd()` methods of `Raw` and `Epochs` objects are very convenient ways of computing PSDs, with equivalent standalone functions for computations on arrays.

However, MNE also offers tools for more advanced time-frequency analyses based on epoched data. These include time-frequency representations ([TFRs](https://mne.tools/stable/documentation/glossary.html#term-tfr)) based on the multitaper, Morlet wavelet, or Stockwell transformation methods:
- Multitaper:
    - From `Epochs` objects: [`mne.time_frequency.tfr_multitaper()`](https://mne.tools/stable/generated/mne.time_frequency.tfr_multitaper.html)
    - From arrays: [`mne.time_frequency.tfr_array_multitaper()`](https://mne.tools/stable/generated/mne.time_frequency.tfr_array_multitaper.html)
- Morlet wavelet:
    - From `Epochs` objects: [`mne.time_frequency.tfr_morlet()`](https://mne.tools/stable/generated/mne.time_frequency.tfr_morlet.html)
    - From arrays: [`mne.time_frequency.tfr_array_morlet()`](https://mne.tools/stable/generated/mne.time_frequency.tfr_array_morlet.html)
- Stockwell transformation:
    - From `Epochs` objects: [`mne.time_frequency.tfr_stockwell()`](https://mne.tools/stable/generated/mne.time_frequency.tfr_stockwell.html)
    - From arrays: [`mne.time_frequency.tfr_array_stockwell()`](https://mne.tools/stable/generated/mne.time_frequency.tfr_array_stockwell.html)

Time-frequency analyses are discussed in more detail here: https://mne.tools/stable/auto_tutorials/time-freq/20_sensors_time_frequency.html#time-frequency-analysis-power-and-inter-trial-coherence

### Part 3 - Spectral filtering to remove artefacts

Spectral filtering is not only useful for isolating activity at some frequencies of interest, but it can also be used to remove artefacts from the data.

The ability to remove technical artefacts was previously mentioned in the context of notch filtering line noise, but biological artefacts such as cardiac activity can also be identified using spectral filtering, and subsequently removed (see e.g. https://labeling.ucsd.edu/tutorial/labels).

Cardiac artefacts can be clearly seen in the MEG channels of MNE's sample data.

In [None]:
# Load the sample data
raw = mne.io.read_raw_fif(
    os.path.join(mne.datasets.sample.data_path(), "MEG", "sample", "sample_audvis_raw.fif")
)
raw.del_proj()  # delete existing PCA projections

# Pick some channels with strong artefacts and plot them
artefact_picks = [152, 155, 158, 164, 167, 170, 272, 275, 278, 284, 287, 290]
raw.plot(order=artefact_picks);

If we want to analyse neural data, failing to remove these non-neural artefacts could lead to erroneous conclusions.

Thankfully, MNE has a convenient function for suppressing cardiac artefacts: [`mne.preprocessing.compute_proj_ecg()`](https://mne.tools/stable/generated/mne.preprocessing.compute_proj_ecg.html).

`compute_proj_ecg()` involves:
- Filtering the data within a given frequency range to isolate the cardiac activity.
- Finding the peaks of cardiac activity.
- Creating epochs around these peaks of activity.
- Using these epochs to create [projection vectors](https://mne.tools/stable/documentation/glossary.html#term-projector) that can be used to minimise the cardiac artefacts in the data.

In [None]:
# Find projections to minimise cardiac artefacts
projs, _ = mne.preprocessing.compute_proj_ecg(raw=raw)

# Add projections to the data (applied on the fly when plotting) and plot the cleaned data
raw.add_proj(projs=projs)
raw.plot(order=artefact_picks);
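
To give some intuition for what a projection vector does, the sketch below builds one by hand on synthetic data using NumPy. It is a heavily simplified stand-in for the steps listed above, not the procedure `compute_proj_ecg()` actually follows: the artefact peaks are assumed to be known in advance, the spatial `pattern` and all amplitudes are made up, and only a single projection vector is estimated.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channels, n_times = 4, 1000

# Hypothetical spatial pattern of the artefact across channels (unit norm)
pattern = np.array([0.8, 0.5, 0.3, 0.1])
pattern /= np.linalg.norm(pattern)

# Artefact time course: a large pulse every 100 samples
artefact = np.zeros(n_times)
artefact[::100] = 50.0

# Simulated recording: background noise plus the artefact in every channel
data = rng.standard_normal((n_channels, n_times)) + np.outer(pattern, artefact)

# Collect the samples at the (here, known) artefact peaks
peak_samples = data[:, ::100]

# The dominant spatial direction of those samples is the projection vector
u, _, _ = np.linalg.svd(peak_samples, full_matrices=False)
direction = u[:, 0]

# Projector that removes this direction from the data
projector = np.eye(n_channels) - np.outer(direction, direction)
cleaned = projector @ data

print(f"Largest value at peaks before: {np.abs(data[:, ::100]).max():.1f}")
print(f"Largest value at peaks after:  {np.abs(cleaned[:, ::100]).max():.1f}")
```

The projector removes everything that lies along the estimated artefact direction, including any neural activity sharing that spatial pattern, which is one reason the number of projection vectors is usually kept small.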

An equivalent function exists for removing eye movement artefacts: [`mne.preprocessing.compute_proj_eog()`](https://mne.tools/stable/generated/mne.preprocessing.compute_proj_eog.html).

## Conclusion

Spectral filtering is an important part of many analyses in neuroscience, involving e.g. the extraction of activity at specific frequencies of interest and the removal of artefacts. The `filter()` method of `Raw` and `Epochs` objects and the `notch_filter()` method of `Raw` objects provide convenient ways of performing such filtering, with equivalent standalone functions for working with arrays of data.

Spectral activity can be represented as PSDs, which can be computed using the `compute_psd()` methods of `Raw` and `Epochs` objects, or the equivalent standalone functions for computations on arrays. More advanced spectral analyses are also offered in the form of TFR computations, with MNE's [Time-Frequency module](https://mne.tools/stable/api/time_frequency.html) also offering several other useful tools, such as for computing cross-spectral densities (CSDs).

## Additional resources

MNE tutorial on spectral analysis: https://mne.tools/stable/auto_tutorials/time-freq/20_sensors_time_frequency.html

MNE tutorial on `Spectrum` and `EpochsSpectrum` classes: https://mne.tools/stable/auto_tutorials/time-freq/10_spectrum_class.html

Video introducing the Fourier transform with some very nice visualisations:
https://youtu.be/spUNpyF58BY?si=hUC2zG8dG6Zah8tP