# 5. Fit Optimization

In this tutorial, waveform motifs will be used to optimize spectral parameterization.

In [None]:
import warnings
from itertools import combinations

import numpy as np
from scipy.stats import ttest_ind
import matplotlib.pyplot as plt

from neurodsp.sim import sim_variable_oscillation, sim_powerlaw
from neurodsp.utils.norm import normalize_sig
from neurodsp.spectral import compute_spectrum, trim_spectrum
from neurodsp.plts import plot_time_series

from fooof import FOOOF
from fooof.sim.gen import gen_periodic

from ndspflow.motif import Motif

### Simulate Oscillations

Asymmetrical sine waves with rise-decay symmetries of .25, .5, and .75, each at a different frequency (8, 10, or 12 Hz), are simulated below.

The resulting oscillations are then combined with a power-law signal to produce a 1/f slope in the frequency domain.

In [None]:
# Simulate asymmetrical sine waves: 50 cycles each at 8, 10, and 12 Hz, with
#   rise-decay symmetries of .25, .5, and .75 respectively
fs = 1000

# sim_powerlaw is stochastic; seed numpy's global RNG so the signal and every
#   downstream fit are reproducible (assumes neurodsp draws from np.random)
np.random.seed(0)

sig_osc = sim_variable_oscillation(None, fs, freqs=np.repeat([8, 10, 12], 50), cycle='asine',
                                   rdsym=np.repeat([.25, .5, .75], 50))

# 1/f^2 aperiodic background sized to match sig_osc: 1 "second" sampled at
#   fs=len(sig_osc) gives one sample per oscillation sample
sig_pl = sim_powerlaw(1, len(sig_osc), exponent=-2)

# Mix oscillation and background, then normalize to zero mean / unit variance
sig = normalize_sig(.8 * sig_osc + sig_pl, mean=0, variance=1)

# Plot. Build the time axis from the sample count so it always matches len(sig):
#   a float-step np.arange(0, len(sig)/fs, 1/fs) can overshoot by one sample.
times = np.arange(len(sig)) / fs
plot_time_series(times, sig)

### SpecParam
Next, the power spectrum of the simulated time series is computed and parameterized using the default (liberal) settings. Later, these settings will be updated based on a waveform shape analysis.

In [None]:
# Power spectrum of the signal computed over 0-100 Hz, then trimmed to 1-100 Hz
freqs, powers = trim_spectrum(*compute_spectrum(sig, fs, f_range=(0, 100)), (1, 100))

# Parameterize with FOOOF's default settings over the same 1-100 Hz range
fm = FOOOF()
fm.fit(freqs, powers, freq_range=(1, 100))

fm.plot(plot_peaks='shade')

### Motifs

Next, waveform motifs are found. The three asymmetrical waves fall within a single ~10 Hz oscillatory peak, so their waveforms are untangled via k-means clustering.

Then, additional peaks with no associated cycles (e.g. harmonics due to asymmetry) are removed. This allows finer tuning of the detected oscillatory spectral peaks.

In [None]:
# Find waveform motifs using the parameterized spectrum `fm` and the raw signal.
#   min_clust_score=0.1 is an ndspflow Motif setting; presumably the minimum
#   clustering score needed to split cycles into sub-motifs (TODO confirm).
motif = Motif(min_clust_score=0.1)
motif.fit(fm, sig, fs)
motif.plot()

In [None]:
# Frequency ranges of motif results; entries whose f_range is not a tuple are
#   skipped (presumably results with no associated cycles)
f_ranges = [res.f_range for res in motif.results if isinstance(res.f_range, tuple)]

# Remove peaks with no associated cycles: keep each peak whose center frequency
#   lies inside ANY motif f_range. Using `any` keeps a peak exactly once, even
#   if it falls in several overlapping ranges (a per-range append would
#   duplicate the peak in peak_params_ / gaussian_params_).
keep_peaks = [
    idx for idx, cf in enumerate(fm.peak_params_[:, 0])
    if any(f_range[0] <= cf <= f_range[1] for f_range in f_ranges)
]

fm.peak_params_ = fm.peak_params_[keep_peaks]
fm.gaussian_params_ = fm.gaussian_params_[keep_peaks]

# Regenerate the model fit from the kept gaussians. This writes FOOOF's private
#   _peak_fit attribute directly; fit metrics such as r_squared_ are not updated here.
fm._peak_fit = gen_periodic(fm.freqs, fm.gaussian_params_.flatten())
fm.fooofed_spectrum_ = fm._peak_fit + fm._ap_fit

# Plot
fm.plot(plot_peaks='shade')

### Separate Oscillations in the Frequency Domain

Three oscillations exist at 8, 10, and 12 Hz. We now run a grid search over the maximum peak width and show how two different optimization metrics (R^2 and center-frequency mean absolute error) can lead to different results.

In [None]:
# Number of waveform clusters found in the first motif result
n_motifs = np.max(motif.results[0].labels) + 1

# Cycle periods grouped by the cluster each cycle was assigned to
periods = [
    motif.results[0].df_features.iloc[np.where(motif.results[0].labels == label)[0]]['period']
    for label in range(n_motifs)
]

# Determine significant differences between motif center frequencies using
#   pairwise t-tests. fs / period converts each cycle's period to a frequency;
#   the mean frequency of every cluster in a significant pair is recorded.
cfs_exp = []
for clust_a, clust_b in combinations(range(n_motifs), 2):
    cyc_freqs_a, cyc_freqs_b = fs / periods[clust_a], fs / periods[clust_b]
    if ttest_ind(cyc_freqs_a, cyc_freqs_b).pvalue < .05:
        cfs_exp.extend([cyc_freqs_a.mean(), cyc_freqs_b.mean()])

# Deduplicate (a cluster may appear in several significant pairs) and sort
cfs_exp = np.unique(cfs_exp)

In [None]:
# Grid search over FOOOF's upper peak-width limit, refining the single remaining peak.

# Search window: the remaining peak's center frequency +/- its bandwidth
assert len(fm.peak_params_) == 1, "Grid search expects exactly one remaining peak."
cf_orig, _, bw_orig = fm.peak_params_[0]

freq_range = (cf_orig - bw_orig, cf_orig + bw_orig)

# Determine bandwidth grid: n_steps upper limits in (1, bw_orig]
n_steps = 100

upper_bounds = np.linspace(1, bw_orig, n_steps + 1)[1:]

# Columns: grid index (into fm_re), CF mean absolute error, R^2.
#   Metrics stay NaN when a fit has the wrong number of peaks.
results = np.full((len(upper_bounds), 3), np.nan)

fm_re = []

# Run grid search; warnings from the repeated fits are silenced for this loop only
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for idx, upper in enumerate(upper_bounds):

        results[idx, 0] = idx

        _fm = FOOOF(peak_width_limits=(1, upper), verbose=False)
        _fm.fit(freqs, powers, freq_range=freq_range)

        # peak_params_ is (n_peaks, 3), so its length is the peak count.
        #   Compared elementwise with the sorted cfs_exp; assumes FOOOF orders
        #   peaks by center frequency (TODO confirm for the installed version).
        cfs = _fm.peak_params_[:, 0]

        if len(cfs) == len(cfs_exp):
            results[idx, 1] = np.abs(cfs_exp - cfs).mean()
            results[idx, 2] = _fm.r_squared_

        fm_re.append(_fm)

# Remove nans (i.e. fits with incorrect number of peaks). Boolean indexing is used
#   instead of np.delete, whose handling of boolean masks differs across numpy versions.
#   Column 0 still maps each remaining row back into fm_re.
results = results[~np.isnan(results[:, 1])]

In [None]:
# Grid index (column 0) of the fit with the HIGHEST R^2. Rows with NaN metrics
#   were removed from `results`, so row positions no longer line up with fm_re;
#   the stored grid index must be used to look up the model.
highest_rsq = int(results[np.argmax(results[:, 2]), 0])

fm_rsq = fm_re[highest_rsq]
fm_rsq.plot(plot_peaks='shade')

plt.title('R^2 Optimized', size=20)
plt.show()

In [None]:
# Grid index (column 0) of the fit whose center frequencies best match cfs_exp,
#   i.e. the row with the smallest CF mean absolute error
lowest_mae = int(results[:, 0][results[:, 1].argmin()])

# Look up and plot the corresponding model from the full grid
fm_mae = fm_re[lowest_mae]
fm_mae.plot(plot_peaks='shade')

plt.title('CF MAE Optimized', size=20)
plt.show()