In [315]:
import matplotlib.pyplot as plt
import numpy as np

from neurodsp.sim import sim_oscillation
from neurodsp.spectral import compute_spectrum

from statsmodels.tsa.stattools import acf

from timescales.sim import sim_lorentzian, sim_exp_decay, sim_branching
from timescales.fit import PSD

from scipy.optimize import minimize
from functools import partial

In [320]:
class DualTS:
    """Dual-objective timescale estimation from a PSD and an ACF.

    For every signal passed to `fit`, the PSD and the ACF are estimated with
    the user-supplied functions, and one shared parameter vector is optimized
    against both with `scipy.optimize.minimize`.

    Parameters
    ----------
    fs : float
        Sampling rate. Forwarded unchanged to `loss_fn`.
    psd_func : callable
        Called as ``psd_func(x)`` on one normalized signal; must return
        ``(freqs, powers)``.
    acf_func : callable
        Called as ``acf_func(x)`` on one normalized signal; must return
        ``(lags, correlations)``.
    loss_fn : callable
        Called as ``loss_fn(freqs, lags, powers, acf, fs, params)``; must
        return the scalar to minimize.
    param_est : sequence of float
        Initial guess handed to the optimizer. The first entry is the value
        that ends up in `fk`.
    param_bounds : sequence of (min, max)
        Bounds handed to the optimizer, one pair per parameter.

    Attributes
    ----------
    psd_, acf_ : 2d array or None
        Per-signal PSD / ACF estimates, one row per signal.
    freqs_, lags_ : array or None
        Frequency / lag axes returned for the first signal (reused for all).
    fk : 1d array or None
        First optimized parameter for each signal.
    params_ : 2d array or None
        Full optimized parameter vector for each signal.
    """

    def __init__(self, fs, psd_func, acf_func, loss_fn, param_est, param_bounds):
        """Store the configuration; no computation happens here."""

        # Sampling rate
        self.fs = fs

        # PSD and ACF estimate functions
        self.psd_func = psd_func
        self.acf_func = acf_func

        # Loss
        self.loss_fn = loss_fn

        # Estimate and bounds
        self.param_est = param_est
        self.param_bounds = param_bounds

        # Results (populated by fit)
        self.psd_ = None
        self.freqs_ = None
        self.acf_ = None
        self.lags_ = None
        self.fk = None
        self.params_ = None

    def fit(self, X):
        """Estimate parameters for each signal in X.

        Parameters
        ----------
        X : 1d or 2d array-like
            One signal, or a stack of signals with one signal per row.

        Returns
        -------
        self : DualTS
            The fitted instance, with the result attributes populated.
        """

        # Accept a single 1d signal as well as a (n_signals, n_samples) stack
        X = np.atleast_2d(np.asarray(X, dtype=float))
        n_sigs = len(X)

        # Init results arrays
        self.fk = np.zeros(n_sigs)
        self.params_ = np.zeros((n_sigs, np.size(self.param_est)))

        # Normalize each signal to zero mean and unit standard deviation
        X = (X - X.mean(axis=1, keepdims=True)) / X.std(axis=1, keepdims=True)

        for i, x in enumerate(X):

            # Compute PSD and ACF transforms
            _freqs, _psd = self.psd_func(x)
            _lags, _acf = self.acf_func(x)

            if i == 0:
                # Sizes are only known once the first transforms are computed
                self.psd_ = np.zeros((n_sigs, len(_psd)))
                self.acf_ = np.zeros((n_sigs, len(_acf)))

                self.freqs_ = _freqs
                self.lags_ = _lags

            # Store the estimates BEFORE binding them into the objective, so the
            # loss never depends on row views being filled in after the fact
            self.psd_[i] = _psd
            self.acf_[i] = _acf

            f_partial = partial(self.loss_fn, self.freqs_, self.lags_,
                                self.psd_[i], self.acf_[i], self.fs)

            res = minimize(f_partial, self.param_est, bounds=self.param_bounds)

            # Keep the whole solution; fk is its first entry
            self.params_[i] = res.x
            self.fk[i] = res.x[0]

        return self


def loss_fn(freqs, lags, powers, acf, fs, params):
    """Joint PSD/ACF objective: the average of the two mean squared errors.

    `params` is unpacked as (fk, psd_exponent, psd_offset, psd_constant,
    acf_amp, acf_offset). The knee frequency `fk` is shared by both models:
    it is passed to the Lorentzian directly and to the exponential decay as
    the time constant 1 / (2 * pi * fk).
    """

    knee, exponent, offset, constant, amp, acf_shift = params

    # Time constant implied by the knee frequency
    tau = 1/(2*np.pi*knee)

    # Residuals of each forward model against its target
    psd_resid = powers - sim_lorentzian(freqs, knee, exponent, offset, constant)
    acf_resid = acf - sim_exp_decay(lags, fs, tau, amp, acf_shift)

    # Equal-weight average of the two mean squared errors
    return ((psd_resid**2).mean() + (acf_resid**2).mean()) / 2

In [341]:
# Simulate
# Seed for reproducible signals.
# NOTE(review): assumes sim_branching draws from NumPy's global RNG - confirm.
np.random.seed(0)

# Ground-truth knee frequency of the simulated signals and its time constant.
# Kept under its own name so the initial guess `fk` below does not overwrite it.
fk_sim = 10
tau = 1/(2*np.pi*fk_sim)

n_seconds = 2
fs = 1000
n_sigs = 100

sigs = np.zeros((n_sigs, int(n_seconds*fs)))

for i in range(n_sigs):
    sig_ap = sim_branching(n_seconds, fs, tau, 1000, mean=0, variance=1)
    # Store each signal z-scored
    sigs[i] = (sig_ap - sig_ap.mean()) / sig_ap.std()

# Initial parameter estimates (each must lie inside its entry in param_bounds)
fk = 20. # set this poorly on purpose
psd_exponent = 2.
psd_offset = 1e-3 # lower bound; the previous start of 0. was outside (1e-3, 1e3)
psd_constant = 0.
acf_amp = 1.
acf_offset = 0.

param_est = (fk, psd_exponent, psd_offset, psd_constant, acf_amp, acf_offset)

# One (min, max) pair per entry of param_est, in the same order
param_bounds = (
    (1, 100),
    (1, 3),
    (1e-3, 1e3),
    (0, 10),
    (0, 2),
    (-1, 1)
)

In [339]:
# ACF and PSD functions
def compute_acf(x, nlags=500):
    """Estimate the autocorrelation function of one signal.

    Parameters
    ----------
    x : 1d array
        Signal passed to `acf`.
    nlags : int, optional, default: 500
        Forwarded to `acf` as its `nlags` argument.

    Returns
    -------
    lags : 1d array
        Integer indices 0 .. len(corrs) - 1, one per returned correlation.
    corrs : 1d array
        Values returned by `acf`.
    """
    corrs = acf(x, nlags=nlags)
    lags = np.arange(len(corrs))
    return lags, corrs

# Pre-bind fs and f_range=(1, 100) so the PSD estimator can be called with the
# signal alone (DualTS.fit calls psd_func(x) with a single argument).
psd_partial = partial(compute_spectrum, fs=fs, f_range=(1, 100))

In [340]:
# Dual-objective model: one joint PSD + ACF fit per simulated signal
dts = DualTS(fs, psd_partial, compute_acf, loss_fn, param_est, param_bounds)
dts.fit(sigs)

# Baseline: standard PSD-only fit on the spectra of the same signals
freqs, powers = psd_partial(sigs)

psd = PSD()
psd.fit(freqs, powers, n_jobs=10)

# Mean squared error of each method's knee frequencies against the reference value
knee_ref = 10
dual_sq_err = (dts.fk - knee_ref)**2
psd_sq_err = (psd.knee_freq - knee_ref)**2

print("Dual Object Knee Freq MSE: ", dual_sq_err.mean())
print("PSD Fit Knee Freq MSE: ", psd_sq_err.mean())

Dual Object Knee Freq MSE:  9.499690861763519
PSD Fit Knee Freq MSE:  37.400582203036215
