# Lab 3: Playing with SOMATA

In [None]:
# author: Mingjian He <email:mh1@stanford.edu>
# License: BSD (3-clause)

# Import packages used throughout the three lab exercises
import numpy as np
import pymatreader
from pprint import pprint
from scipy import signal
import matplotlib.pyplot as plt

# Global setting of matplotlib figures
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "Helvetica",
    "figure.constrained_layout.use": True,
    "savefig.dpi": 300
})

## Exercise 1: getting familiar with the syntax
Learning objective: understand the common syntax for manipulating SOMATA basic models

In [None]:
# Import the four different basic models in SOMATA
from somata.basic_models import StateSpaceModel as Ssm
from somata.basic_models import GeneralSSModel as Gen
from somata.basic_models import OscillatorModel as Osc
from somata.basic_models import AutoRegModel as Arn

### 1.1 `print()` and `append()`

The constructors of these basic models can be called without any arguments.

In [None]:
s1 = Ssm()  # create an empty state-space model instance without any parameters
s1  # displays the output of the object's __repr__() method

Invoking `print()` gives helpful summary information about the state-space model object.

This is one of the most heavily used methods in SOMATA.

In [None]:
print(s1)

Appending one model to another model augments the state-space parameters in block-diagonal form.

In [None]:
s1 = Ssm(F=1)
s1

The model on which `append()` is called is modified in place.

In [None]:
s2 = Ssm(F=0.9)
s1.append(s2)
print(s1)
print('s1.F', s1.F)
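
For intuition, the block-diagonal augmentation of `F` that `append()` produces above can be reproduced with `scipy.linalg.block_diag`. This is a plain NumPy/SciPy sketch of the resulting matrix, not SOMATA's internal code.

```python
import numpy as np
from scipy.linalg import block_diag

# Transition matrices of the two single-state models (F=1 and F=0.9)
F1 = np.array([[1.0]])
F2 = np.array([[0.9]])

# Appending augments the hidden state space: the new F is block-diagonal,
# so the two hidden states evolve independently of each other
F_appended = block_diag(F1, F2)
print(F_appended)  # 2x2 matrix with 1.0 and 0.9 on the diagonal, zeros elsewhere
```

The zero off-diagonal blocks are what make each appended model a separable piece of the larger model.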

### 1.2 The concept of `components`

An alternative way to view this new model `s1` is as a state-space model with two components.

The first component has a univariate state in the hidden state space, and so does the second component.

We can recreate such a state-space model using the `components` argument.

In [None]:
g1 = Gen(F=1)
g2 = Gen(F=0.9)
s1 = Ssm(components=[g1, g2])
s1

Notice that the `__repr__()` method now returns `Ssm(2)<....>`, which indicates that this state-space model now has two components (`ncomp=2`).

Utilizing `components` to represent state-space models is advantageous in many ways. To list a few:
1. It gives a much clearer view of the structure of the state-space model at a glance.
2. It allows model parameters to be updated in parallel during the M-step of the EM algorithm, thanks to the block-diagonal structure.
3. It provides a breadboard for different combinations of state-space model components.

In [None]:
print(s1)
print('s1.F', s1.F)
print('s1.G', s1.G)

The four characters in `<>` are the last four digits of the memory address of the model instance, which helps keep track of multiple models.

Notice how `g1` and `g2` have become the component models within `s1`, indicated by the same memory addresses.

In [None]:
print(g1)
print(s1.components[0])

However, when `s1` undergoes EM, the component models `g1` and `g2` will not get updated automatically.

This is so that computations are not slowed down by unnecessary `setattr()` calls. If you would like to map parameters in `s1` back to its components, use `fill_components()`.

If you want to break this memory-address link, simply chain `.copy()` onto the constructor call.

In [None]:
s2 = Ssm(components=[g1, g2]).copy()
print(s2)

Since `StateSpaceModel` is the parent class of the other basic models in SOMATA, it cannot itself be a component.

This constraint is to avoid ambiguous handling of component state-space models. All other basic models can be components.

In [None]:
s1 = Ssm(F=1)
s2 = Ssm(F=0.9)
# s3 = Ssm(components=[s1, s2])  # this won't work by design

### 1.3 `GeneralSSModel` class

`GeneralSSModel` prints the same attributes as `StateSpaceModel`.

In [None]:
g1 = Gen(F=1)
g1

Notice that the `components` attribute is automatically populated with a placeholder component model, which did not happen for `StateSpaceModel`.

In [None]:
print(g1)

Unlike the component models in the earlier example that used the `components=` argument, this component model does not carry any parameters.

**It is important to emphasize that `components` models should not be used directly throughout SOMATA!**

Whether or not they carry parameters, component models act as placeholders that organize the overarching state-space model rather than being used for computations themselves.

In [None]:
print(g1.components[0])

### 1.4 `OscillatorModel` class

`OscillatorModel` has special attributes that are printed to help quickly understand the oscillator model.

In [None]:
o1 = Osc()
o1

In [None]:
print(o1)

This also means that we can create an `OscillatorModel` using more concise parameter arguments.

The default $\sigma^2=3$ gets filled in when `sigma2` isn't provided, but you would likely want to set it explicitly.

In [None]:
o1 = Osc(a=0.95, w=np.pi/10)
print(o1)

Notice how the `freq` display changes to the physical unit `Hz` when the sampling rate (`Fs`) is also provided.

In [None]:
o1 = Osc(a=0.95, w=np.pi/10, Fs=250)
print(o1)

We can also directly provide the rotation frequency in `Hz` along with `Fs` to create an `OscillatorModel`, which is very useful in practice.

In [None]:
o1 = Osc(a=0.99, freq=10, Fs=100)
print(o1)
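
The `w`/`freq`/`Fs` arguments above can be related with a short NumPy sketch. It assumes `OscillatorModel` uses the standard damped-rotation parameterization, i.e., a 2x2 transition matrix $F = a\,R(\omega)$ where $R(\omega)$ rotates by $\omega = 2\pi f / F_s$ radians per sample; the helper names `osc_transition`, `hz_to_rad`, and `rad_to_hz` are hypothetical.

```python
import numpy as np

def osc_transition(a, w):
    """Hypothetical helper: 2x2 transition matrix of a damped oscillator,
    i.e., a rotation by w (radians/sample) scaled by the damping factor a."""
    return a * np.array([[np.cos(w), -np.sin(w)],
                         [np.sin(w),  np.cos(w)]])

def hz_to_rad(freq, Fs):
    """Convert a frequency in Hz to a rotation angle in radians per sample."""
    return 2 * np.pi * freq / Fs

def rad_to_hz(w, Fs):
    """Convert a rotation angle in radians per sample to a frequency in Hz."""
    return w * Fs / (2 * np.pi)

# Osc(a=0.95, w=np.pi/10, Fs=250) should then display a frequency of 12.5 Hz
print(rad_to_hz(np.pi / 10, 250))

# Osc(a=0.99, freq=10, Fs=100) corresponds to w = 0.2*pi rad/sample
F = osc_transition(0.99, hz_to_rad(10, 100))

# The eigenvalues of F have modulus a (damping) and angle +/- w (rotation)
eigvals = np.linalg.eigvals(F)
print(np.abs(eigvals), np.angle(eigvals))
```

Reading the eigenvalues this way is a handy check: the modulus controls how quickly the oscillation decays, and the angle sets its center frequency.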

### 1.5 `AutoRegModel` class

`AutoRegModel` has a more informative `__repr__()`, since the AR order is more relevant information when there is only one component.

In [None]:
a1 = Arn()
a1

In [None]:
a2 = Arn(coeff=[0.9, -0.5, 0.3])
a2

In [None]:
print(a2)
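
An AR(p) process fits into the state-space framework through the textbook companion form: the state stacks the last p values, the first row of `F` holds the AR coefficients, and only the newest value receives process noise. The sketch below builds these matrices for `coeff=[0.9, -0.5, 0.3]`; the helper `ar_companion` is hypothetical, and whether `AutoRegModel` uses exactly this layout internally is an assumption.

```python
import numpy as np

def ar_companion(coeff, sigma2):
    """Hypothetical helper: companion-form (F, Q, G) for an AR(p) process
    x_t = coeff[0] x_{t-1} + ... + coeff[p-1] x_{t-p} + w_t, w_t ~ N(0, sigma2)."""
    p = len(coeff)
    F = np.zeros((p, p))
    F[0, :] = coeff             # first row: AR coefficients
    F[1:, :-1] = np.eye(p - 1)  # sub-diagonal shifts past values down by one lag
    Q = np.zeros((p, p))
    Q[0, 0] = sigma2            # only the newest value receives process noise
    G = np.zeros((1, p))
    G[0, 0] = 1.0               # observe the newest value
    return F, Q, G

F, Q, G = ar_companion([0.9, -0.5, 0.3], sigma2=1.0)
print(F)
```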

You can also create a multi-component AR model by passing a list of lists of parameters.

In [None]:
a3 = Arn(coeff=[[0.9, -0.5], [0.3, 0.1]])
a3

Like `OscillatorModel`, the default $\sigma^2=3$ gets filled in, but you would likely want to set it explicitly.

In [None]:
print(a3)

Note that `AutoRegModel` does not support multivariate AR models because the display gets too clunky.

Use `StateSpaceModel` to hold high-dimensional state-space models instead.

In [None]:
# mva1 = Arn(F=[[0.9, -0.5], [0.3, 0.1]], Q=[[0.9, 0.2], [0.2, 0.5]])  # this won't work by design
mva1 = Ssm(F=[[0.9, -0.5], [0.3, 0.1]], Q=[[0.9, 0.2], [0.2, 0.5]])
print(mva1)

### 1.6 State-space models with heterogeneous `components`

Now let's try something fancy. What if we want to build a state-space model with both an AR model and an oscillator model as its components?

The first way is to pass the AR and oscillator models as `components` into a new state-space model.

In [None]:
a1 = Arn(coeff=[0.9, -0.5, 0.3])
o1 = Osc(a=0.99, freq=10, Fs=100)
s1 = Ssm(components=[a1, o1])
print(s1)

The second way is to create the two models and then use the `concat_()` method instead of `append()`.

In [None]:
a1 = Arn(coeff=[0.9, -0.5, 0.3])
o1 = Osc(a=0.99, freq=10, Fs=100)
s1 = a1.concat_(o1)  # this returns a new model instance instead of modifying a1 in place like in append()
print(s1)

Notice that unlike in the first way, the memory addresses of the component models in `s1` are different from those of `a1` and `o1`.

In [None]:
print(a1)
print(o1)

### 1.7 Automatic initialization of `components`

The `components` argument can also be used to automatically construct components of a specific model type from the given parameters.

In [None]:
s1 = Ssm(components='Arn', F=0.9, Q=1)
print(s1)

We can use this syntax to automatically parse block-diagonal parameters into multiple components.

For example, two AR(1) models with filled-in parameters are created below as `components`. We will later see that this is convenient for `OscillatorModel` as well.

In [None]:
s1 = Ssm(components='Arn', F=[[0.9, 0], [0, 0.8]], Q=[[1, 0], [0, 1]])
print(s1)
print(s1.components[0])

### 1.8 Model stacking and arrays of models

Unlike `append()` or `concat_()`, adding one model to another with `+` "stacks" mismatched parameters (`F` in this case) along a third dimension and increases `nmodel` by 1.

In [None]:
s1 = Ssm(F=1, Q=0.5)
s2 = Ssm(F=0.9, Q=0.5)
s3 = s1 + s2
s3

Parameters that are matched between state-space models (`Q` in this case) are kept intact to be memory-efficient.

In [None]:
print(s3)
print('s3.F', s3.F)
print('s3.Q', s3.Q)
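
The stacked layout described above can be pictured in plain NumPy, assuming (as the text states) that the model index runs along the third axis of a mismatched parameter. This is an illustration of the array shapes, not SOMATA's internal code.

```python
import numpy as np

# Mismatched parameter: F differs between the two models, so it is stacked
# along a third axis of shape (nstate, nstate, nmodel)
F_stacked = np.dstack([np.array([[1.0]]), np.array([[0.9]])])
print(F_stacked.shape)  # (1, 1, 2)

# Matched parameter: Q is identical, so a single copy suffices
Q_shared = np.array([[0.5]])

# Slicing the third axis recovers each model's F
for k in range(F_stacked.shape[2]):
    print(f'model {k}: F =', F_stacked[:, :, k], 'Q =', Q_shared)
```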

If we multiply two models with different parameters using `*`, models formed from all combinations of their parameters are stacked together.

In [None]:
s1 = Ssm(F=1, Q=0.5, R=0.3)
s2 = Ssm(F=0.9, Q=0.5, R=0.5)
s3 = s1 * s2
print(s3)
print('s3.F', s3.F)
print('s3.Q', s3.Q)
print('s3.R', s3.R)
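
The combinatorial idea behind `*` can be pictured with `itertools.product`: collect the distinct values each parameter takes across the two models and enumerate every combination. This is an illustration only; the ordering and bookkeeping SOMATA uses when stacking the resulting models are not reproduced here.

```python
import itertools

# Distinct parameter values in s1 = Ssm(F=1, Q=0.5, R=0.3) and s2 = Ssm(F=0.9, Q=0.5, R=0.5)
params = {
    'F': [1.0, 0.9],
    'Q': [0.5],        # matched between the models, so only one distinct value
    'R': [0.3, 0.5],
}

# Every combination of distinct parameter values defines one candidate model
combos = [dict(zip(params, values)) for values in itertools.product(*params.values())]
for combo in combos:
    print(combo)
print(len(combos), 'candidate models')  # 2 * 1 * 2 combinations
```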

Raising a model to a positive power with `**` forms combinations across its own parameters.

In [None]:
s1 = Ssm(F=1, Q=0.1, R=0.3)
s2 = Ssm(F=0.9, Q=0.2, R=0.5)
s3 = s1 + s2
print(s3)
s4 = s3 ** 2
print(s4)
print('s4.F', s4.F)
print('s4.Q', s4.Q)
print('s4.R', s4.R)

`s3` and `s4` are essentially condensed models representing multiple alternative state-space models, which differ in some parameters.

We can turn them into arrays of state-space models for easy looping with the `stack_to_array()` method.

In [None]:
print(s4.stack_to_array())
print(type(s4.stack_to_array()))

for model_no, model in enumerate(s4.stack_to_array()):
    print(f'Model {model_no}: F =', model.F, 'Q =', model.Q, 'R =', model.R)

### 1.9 Now that you have learned how to work with SOMATA basic models, let's give it a try

**Create a seasonal adjustment model with the following components:**

1. A random walk of variance $0.8$, i.e., $x_t = x_{t-1} + w_t$ and $w_t \sim \mathcal{N}(0, 0.8)$.
2. An oscillation with a damping factor of $0.96$, a central frequency of $20$ Hz, and mean-zero process noise of variance $1.2$.
3. An autoregressive process with coefficients $[0.7, -0.35, 0.5]$ and process noise of variance $0.3$.
4. An unnamed dynamical process with transition matrix $\begin{bmatrix} 0.3 & 0.4\\ 0.5 & 0.6 \end{bmatrix}$, state noise covariance $\begin{bmatrix} 0.6 & 0.23\\ 0.23 & 0.4 \end{bmatrix}$, and observation matrix $\begin{bmatrix} 0.112 & 3.58 \end{bmatrix}$.

The seasonal adjustment model is used to track a stream of univariate observations collected at a $250$ Hz sampling rate, and you estimate the observation noise variance to be $1.0$.

In [None]:
m1 = Arn(F=1, Q=0.8)
m2 = Osc(a=0.96, freq=20, sigma2=1.2, Fs=250)
m3 = Arn(coeff=[0.7, -0.35, 0.5], sigma2=0.3)
m4 = Gen(F=[[0.3, 0.4], [0.5, 0.6]], Q=[[0.6, 0.23], [0.23, 0.4]], G=[0.112, 3.58])

season_model = Ssm(components=[m1, m2, m3, m4], R=1.0)
print(season_model)

## Exercise 2: running EM algorithm with state-space models
Learning objective: understand how to perform state inference and parameter estimation with SOMATA basic models

In [None]:
# Import two of the basic models in SOMATA
from somata.basic_models import StateSpaceModel as Ssm
from somata.basic_models import OscillatorModel as Osc

### 2.1 Load the data and visualize

Load a 10-s example time series and create a state-space model.

Notice that `components=None`. We will come back to this later.

In [None]:
# Load example state-space model parameters from a .mat file
ssm_params = pymatreader.pymatreader.read_mat('example_ssm_parameters.mat')
# Remove the MATLAB header (dunder) keys, if present, so that only model parameters remain
for key in ['__header__', '__version__', '__globals__']:
    ssm_params.pop(key, None)

s1 = Ssm(F=ssm_params['F'], Q=ssm_params['Q'], mu0=ssm_params['mu0'], S0=ssm_params['S0'],
         G=ssm_params['G'], R=ssm_params['R'], y=ssm_params['y'], Fs=ssm_params['Fs'])
print(s1)

Plot the time trace. (Optional): Use your spectral estimation method from previous labs to examine the spectrum of this data.

In [None]:
plt.plot(s1.y.T)
plt.xlabel('Time (samples)')
plt.ylabel('Arbitrary units')
_ = plt.title('Example time series data')

The parameters in `ssm_params` are already set up as two independent oscillators. Let's create a more informative state-space model object using `OscillatorModel`.

In [None]:
o1 = Osc(F=ssm_params['F'], Q=ssm_params['Q'], mu0=ssm_params['mu0'], S0=ssm_params['S0'],
         G=ssm_params['G'], R=ssm_params['R'], y=ssm_params['y'], Fs=ssm_params['Fs'])  # automatically parses oscillators
o1

This state equation contains two oscillators in the hidden state space: one at $1.48$ Hz and one at $12.66$ Hz. These parameters were estimated beforehand with 50 iterations of EM.

In [None]:
print(o1)

### 2.2 State inference with Kalman filtering and smoothing

Performing Kalman filtering and smoothing with SOMATA basic models is very simple.

In [None]:
kalman_results = o1.kalman_filt_smooth(return_dict=True)
_ = [print(x) for x in kalman_results.keys()]

When running EM, we don't need all of these outputs, so we can request a more concise dictionary with `EM=True`.

In [None]:
em_kalman_results = o1.kalman_filt_smooth(EM=True)
_ = [print(x) for x in em_kalman_results.keys()]
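
To demystify what the filtered (`x_t_t`) and smoothed (`x_t_n`) estimates contain, here is a minimal textbook Kalman filter with a Rauch-Tung-Striebel backward pass in plain NumPy. It is a sketch for intuition, not SOMATA's implementation; the helper name `kalman_rts_sketch` is hypothetical, and the index convention (column 0 holds the initial state $x_0$, so there is one more state column than observations) is an assumption.

```python
import numpy as np

def kalman_rts_sketch(y, F, Q, G, R, mu0, S0):
    """Hypothetical helper: Kalman filter and Rauch-Tung-Striebel smoother.

    y has shape (nchannel, T); states are indexed 0..T with x_0 ~ N(mu0, S0),
    and y[:, t-1] observes x_t. Returns filtered (x_t_t, P_t_t) and
    smoothed (x_t_n, P_t_n) means and covariances.
    """
    nstate, T = F.shape[0], y.shape[1]
    x_t_t = np.zeros((nstate, T + 1))
    P_t_t = np.zeros((T + 1, nstate, nstate))
    x_pred = np.zeros((nstate, T + 1))
    P_pred = np.zeros((T + 1, nstate, nstate))
    x_t_t[:, 0], P_t_t[0] = mu0, S0

    for t in range(1, T + 1):
        # Predict: propagate the previous filtered estimate through the dynamics
        x_pred[:, t] = F @ x_t_t[:, t - 1]
        P_pred[t] = F @ P_t_t[t - 1] @ F.T + Q
        # Update: correct the prediction with the innovation y_t - G x_{t|t-1}
        S = G @ P_pred[t] @ G.T + R
        K = P_pred[t] @ G.T @ np.linalg.inv(S)
        x_t_t[:, t] = x_pred[:, t] + K @ (y[:, t - 1] - G @ x_pred[:, t])
        P_t_t[t] = (np.eye(nstate) - K @ G) @ P_pred[t]

    # Backward pass: refine each filtered estimate using all future observations
    x_t_n, P_t_n = x_t_t.copy(), P_t_t.copy()
    for t in range(T - 1, -1, -1):
        J = P_t_t[t] @ F.T @ np.linalg.inv(P_pred[t + 1])
        x_t_n[:, t] = x_t_t[:, t] + J @ (x_t_n[:, t + 1] - x_pred[:, t + 1])
        P_t_n[t] = P_t_t[t] + J @ (P_t_n[t + 1] - P_pred[t + 1]) @ J.T
    return x_t_t, P_t_t, x_t_n, P_t_n

# Usage on a scalar AR(1) state observed in noise (toy data)
rng = np.random.default_rng(0)
y_demo = rng.normal(size=(1, 200))
x_t_t, P_t_t, x_t_n, P_t_n = kalman_rts_sketch(
    y_demo, F=np.array([[0.9]]), Q=np.array([[1.0]]), G=np.array([[1.0]]),
    R=np.array([[1.0]]), mu0=np.zeros(1), S0=np.eye(1))
```

Because the smoother also uses future observations, its posterior variance is never larger than the filter's, and the two estimates coincide at the last time point.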

Let's look at the Kalman filtered estimates.

In [None]:
fig, ax = plt.subplots(4, 1, figsize=(10, 10))
ax[0].plot(kalman_results['x_t_t'][0, :], label='Oscillator 1 - Real')
ax[1].plot(kalman_results['x_t_t'][1, :], label='Oscillator 1 - Imaginary')
ax[2].plot(kalman_results['x_t_t'][2, :], label='Oscillator 2 - Real')
ax[3].plot(kalman_results['x_t_t'][3, :], label='Oscillator 2 - Imaginary')
for axx in ax:
    axx.legend(loc='upper right')
    axx.set_ylabel('Arbitrary units')
ax[-1].set_xlabel('Time (samples)')
_ = ax[0].set_title('Kalman filtering estimates')

If we have multiple models stacked together or an array of models, we can run Kalman filtering and smoothing in parallel with `par_kalman()`.

In [None]:
o2 = o1.copy()
o2.freq = np.array([2, 13])  # manually change the oscillator rotation center frequencies
o2.fill_ssm_param()  # this propagates the changes to the state-space model parameters (F, Q)
print(o2)
print('o1.F', o1.F)
print('o2.F', o2.F)

We can create a stacked model then run Kalman filtering and smoothing across all underlying models.

In [None]:
o3 = o1 + o2  # this creates a stacked model
print(o3)

# We can directly call the par_kalman() method on the stacked model
par_kalman_results = Ssm.par_kalman(o3, return_dict=True)

Each of the `_all` keys points to a list of corresponding results for the array of models.

In [None]:
_ = [print(x) for x in par_kalman_results.keys()]
par_kalman_results['x_t_n_all']

Equivalently, we can create an array of models and call the `par_kalman()` method.

This syntax is quite flexible, since the array can contain arbitrary SOMATA state-space models, as long as Kalman filtering and smoothing can be performed.

By default, `par_kalman()` checks that all models share the same observed data `y`, to guard against erroneous model comparisons.

If you really intend to run with different data, pass in `skip_check_observed=True`, but be mindful of the resulting estimation differences.

In [None]:
model_array = o3.stack_to_array()  # equivalent to model_array = [o1, o2]
par_kalman_results = Ssm.par_kalman(model_array, return_dict=True, skip_check_observed=False)
_ = [print(x) for x in par_kalman_results.keys()]

### 2.3 Run EM algorithm on state-space models

It is very simple to run the vanilla EM algorithm with maximum likelihood estimation (MLE).

However, there are a few things to take note of:
1. The `**` syntax simply passes the dictionary key-value pairs as keyword arguments into the `m_estimate()` method.
2. The `EM=True` flag is necessary, since otherwise the full dictionary of Kalman results would pass extra arguments to `m_estimate()`.
3. The model on which `m_estimate()` is called is modified in place, including all of its parameters but not its component models.
4. Both `kalman_filt_smooth()` and `m_estimate()` can take in different observed data using the `y=` argument, which is useful for handling data segments.
5. While a plain model, like `s1` above, can run Kalman filtering and smoothing, a model must have `components` in order to invoke the `m_estimate()` method.
6. The `m_estimate()` method looks into `components` to retrieve component-specific update methods and to update matrix blocks in parallel.

In [None]:
o1.m_estimate(**o1.kalman_filt_smooth(EM=True))

Therefore, in order to follow the oscillator-specific update equations, one needs to use the `OscillatorModel` constructor from the beginning.

Or one can pass in `components='Osc'` when using the `StateSpaceModel` constructor. This ensures that the component models are of `OscillatorModel` type.

In [None]:
# s1.m_estimate(**s1.kalman_filt_smooth(EM=True))  # this won't work by design

s1 = Ssm(F=ssm_params['F'], Q=ssm_params['Q'], mu0=ssm_params['mu0'], S0=ssm_params['S0'],
         G=ssm_params['G'], R=ssm_params['R'], y=ssm_params['y'], Fs=ssm_params['Fs'], components='Osc')
print(s1)
s1.m_estimate(**s1.kalman_filt_smooth(EM=True))  # this works because component models are OscillatorModel instances
print('s1.Q', s1.Q)  # this is diagonal

A different set of update rules will be employed if a different component model class is assumed.

In [None]:
s1 = Ssm(F=ssm_params['F'], Q=ssm_params['Q'], mu0=ssm_params['mu0'], S0=ssm_params['S0'],
         G=ssm_params['G'], R=ssm_params['R'], y=ssm_params['y'], Fs=ssm_params['Fs'], components='Gen')
print(s1)
s1.m_estimate(**s1.kalman_filt_smooth(EM=True))  # this uses the full matrix update rules for the GeneralSSModel instances
print('s1.Q', s1.Q)  # no longer diagonal

One can easily run multiple EM iterations in a for loop. Let's start with some arbitrary guesses of parameters and see how EM increases the log-likelihood.

In [None]:
o1 = Osc(a=[0.99, 0.99], freq=[1, 10], y=ssm_params['y'], Fs=ssm_params['Fs'])
print('Before EM: o1.freq', o1.freq)
logL = []
for _ in range(50):  # run 50 iterations of EM
    kalman_results = o1.kalman_filt_smooth(EM=True)
    logL.append(kalman_results['logL'].sum())  # sum over time points
    o1.m_estimate(**kalman_results)
print('After EM: o1.freq', o1.freq)

plt.plot(logL)
plt.xlabel('EM iteration')
plt.ylabel('Log-likelihood')
_ = plt.title('Log-likelihood over EM iterations')
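
The per-time-point `logL` values summed above come from the prediction-error decomposition: each observation contributes the Gaussian log-density of its innovation (one-step prediction error). Below is a scalar sketch for the model $x_t = F x_{t-1} + w_t$, $y_t = x_t + v_t$; the helper name `kalman_loglik_sketch` is hypothetical, and SOMATA's own bookkeeping (e.g., whether constant terms are included) may differ.

```python
import numpy as np

def kalman_loglik_sketch(y, F, Q, R, mu0, S0):
    """Hypothetical helper: per-time-point log-likelihood of the scalar model
    x_t = F x_{t-1} + w_t, y_t = x_t + v_t, with w_t ~ N(0, Q), v_t ~ N(0, R)
    and x_0 ~ N(mu0, S0), via the prediction-error decomposition."""
    x, P = mu0, S0
    logL = np.zeros(len(y))
    for t, y_t in enumerate(y):
        # One-step prediction of the state and its variance
        x_pred = F * x
        P_pred = F * P * F + Q
        # Innovation and its variance
        e = y_t - x_pred
        S = P_pred + R
        logL[t] = -0.5 * (np.log(2 * np.pi * S) + e ** 2 / S)
        # Measurement update
        K = P_pred / S
        x = x_pred + K * e
        P = (1 - K) * P_pred
    return logL

rng = np.random.default_rng(0)
y_demo = rng.normal(size=100)
logL_demo = kalman_loglik_sketch(y_demo, F=0.9, Q=1.0, R=0.5, mu0=0.0, S0=1.0)
print(logL_demo.sum())
```

Summing these terms gives exactly the log-density of the whole observed sequence under the model, which is the quantity EM is climbing in the plot above.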

Priors can be used as long as they are implemented in the corresponding `_m_update_...()` methods of the component model class.

Details of these priors can be found in the implementations of model parameter update methods called within `m_estimate()` during the EM algorithm.

One can also control which parameters get updated using the `update_param=` or `keep_param=` arguments. Examples are omitted here, but you are encouraged to explore their usage.

In [None]:
priors = o1.initialize_priors()  # a list of priors, one dictionary per component
pprint(priors)
o1.m_estimate(**o1.kalman_filt_smooth(EM=True), priors=priors)  # maximum a posteriori (MAP) estimation

### 2.4 Comparison with Hilbert transform after bandpass filtering

A common approach neuroscientists use to extract the amplitude and phase of neural oscillations is bandpass filtering followed by the Hilbert transform.

Let's compare the results of this approach to state-space oscillator modeling of the ~$12$ Hz oscillation.

In [None]:
# First, find the lowest filter order that meets the desired passband and stopband specifications
N, Wn = signal.buttord(wp=[10, 14], ws=[9, 15], gpass=1, gstop=50, fs=ssm_params['Fs'])

# Then we design the bandpass filter
sos = signal.butter(N, Wn, btype="bandpass", output='sos', fs=ssm_params['Fs'])

# Visualize the filter frequency response - it's a good habit to always check your filter design
w, h = signal.sosfreqz(sos, fs=ssm_params['Fs'])
fig, ax = plt.subplots()
ax.plot(w, 20 * np.log10(np.maximum(np.abs(h), 1e-3)))
ax.set_title('Butterworth filter frequency response')
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel('Amplitude [dB]')
ax.margins(0, 0.1)
ax.grid(which='both', axis='both')
_ = ax.set_xlim([0, ssm_params['Fs']/2])

Apply the filter we designed with the above frequency response and visualize the filtered time series.

In [None]:
# This preserves the phase by filtering forward and backward
y_filt = signal.sosfiltfilt(sos, np.squeeze(ssm_params['y']))

# Visualize the filtered time series
plt.plot(y_filt)
plt.xlabel('Time (samples)')
plt.ylabel('Arbitrary units')
_ = plt.title('Filtered data in 10-14 Hz')

Next, we apply the Hilbert transform to obtain the complex analytic signal. We also recompute the smoothing estimates from state-space modeling.

In [None]:
y_hilb = signal.hilbert(y_filt)

# Kalman smoothing using saved parameters that have been learned after 50 EM iterations
kalman_results = Osc(**ssm_params).kalman_filt_smooth(return_dict=True)
y_ssm = kalman_results['x_t_n'][2, :] + 1j * kalman_results['x_t_n'][3, :]  # form a complex signal

Compare the analytic signal to the state-space modeling smoothing estimates.

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(10, 5))
ax[0].plot(y_ssm.real, label='Oscillator 2 - Real')
ax[0].plot(y_hilb.real, label='Analytic signal - Real')
ax[1].plot(y_ssm.imag, label='Oscillator 2 - Imaginary')
ax[1].plot(y_hilb.imag, label='Analytic signal - Imaginary')

for axx in ax:
    axx.legend(loc='lower left')
    axx.set_ylabel('Arbitrary units')
ax[-1].set_xlabel('Time (samples)')
_ = ax[0].set_title('Kalman smoothing estimates and Hilbert analytic signal')

We can also compare the instantaneous amplitudes and phases of the extracted oscillation.

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(10, 5))
ax[0].plot(np.abs(y_ssm), label='Oscillator 2')
ax[0].plot(np.abs(y_hilb), label='Analytic signal')
ax[1].plot(np.unwrap(np.angle(y_ssm)), label='Oscillator 2')  # unwrap phase to see the difference more clearly
ax[1].plot(np.unwrap(np.angle(y_hilb)), label='Analytic signal')

for axx in ax:
    axx.legend(loc='upper left')
ax[0].set_ylabel('Amplitude (a.u.)')
ax[1].set_ylabel('Phase (rad)')
ax[-1].set_xlabel('Time (samples)')
_ = ax[0].set_title('Amplitude and phase comparisons')
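
As a sanity check on what `np.abs()` and `np.angle()` extract from an analytic signal, consider a noiseless cosine. The tone below is chosen with a whole number of cycles in the window, so the FFT-based `signal.hilbert` has no edge effects; the sampling rate and tone frequency are illustrative choices, not the lab data.

```python
import numpy as np
from scipy import signal

Fs, f0, duration = 100, 12, 10              # sampling rate (Hz), tone frequency (Hz), seconds
t = np.arange(int(Fs * duration)) / Fs
y_tone = 2.0 * np.cos(2 * np.pi * f0 * t)   # pure cosine with amplitude 2

z = signal.hilbert(y_tone)                  # complex analytic signal
amplitude = np.abs(z)                       # instantaneous amplitude (envelope)
phase = np.unwrap(np.angle(z))              # instantaneous phase (rad)

# For a pure tone, the envelope is flat at the tone amplitude and the phase
# advances by 2*pi*f0/Fs radians per sample
print(amplitude.min(), amplitude.max())
print(np.mean(np.diff(phase)), 2 * np.pi * f0 / Fs)
```

Real recordings are noisy and broadband, so their envelopes are never this clean; that is where the filter design choices above start to matter.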

The take-home message is that, at best, "Hilbert transform after bandpass filtering" can (almost) match the state-space modeling estimates.

However, the analytic approach with Hilbert transform has a few common pitfalls:
- The oscillation needs to be strong and to dominate the filtered frequency range.
- The filter needs to be carefully designed and tuned to the underlying oscillation.
- Even then, the instantaneous amplitude of the filtered signal is often smoother than reality.

On the other hand, state-space modeling provides reliable estimates after EM learning, which adapts the model parameters to the observed data.

## Exercise 3: searching for oscillations in data
Learning objective: learn to use and interpret the oscillator search methods in SOMATA

In [None]:
# Import the oscillator basic model in SOMATA
from somata.basic_models import OscillatorModel as Osc

# Import the two oscillator search methods in SOMATA
from somata.oscillator_search import IterativeOscillatorModel as IterOsc
from somata.oscillator_search import DecomposedOscillatorModel as DecOsc

### 3.1 Simulate oscillatory data with state-space models

First create an oscillator model instance with two oscillators.

In [None]:
o1 = Osc(a=[0.996, 0.95], freq=[0.1, 10], sigma2=[0.4, 0.2], R=1.2, Fs=100)
print(o1)

It had not occurred to me until now that we need a method to simulate data from state-space models.

Perhaps I have been spoiled by abundant experimental data in the past. Anyway, let's write one together.

In [None]:
def simulate_data(model, duration):
    """
    Simulate data from a state-space model.

    Parameters
    ----------
    model : StateSpaceModel
        A SOMATA state-space model instance.
    duration : int
        Simulation duration in seconds.

    Returns
    -------
    x : ndarray
        Latent states.
    y : ndarray
        Simulated observed data.
    """
    # total number of samples in the time series
    T = int(duration * model.Fs)

    # initialize tallies for the latent states and observations
    x = np.zeros((model.nstate, T + 1))
    y = np.zeros((model.nchannel, T))

    # initial state at t=0
    x[:, 0] = np.random.multivariate_normal(np.squeeze(model.mu0), model.S0)

    # iterate through the rest of time points and generate observations
    for ii in range(1, T + 1):
        x[:, ii] = model.F @ x[:, ii - 1] + np.random.multivariate_normal(np.zeros(model.nstate), model.Q)
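        # R is the observation noise (co)variance, so draw from a Gaussian with covariance R;
        # np.random.normal would instead treat R as a standard deviation
        y[:, ii - 1] = model.G @ x[:, ii] + np.random.multivariate_normal(np.zeros(model.nchannel), np.atleast_2d(model.R))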

    return x, y

Simulate $10$ s of data and keep the ground-truth hidden states as well.

In [None]:
np.random.seed(404)  # for reproducibility
x, y = simulate_data(o1, duration=10)
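
Before handing simulated data to a search algorithm, it can be worth checking that a simulator behaves as theory predicts. For a damped oscillator with $F = a\,R(\omega)$ and $Q = \sigma^2 I$, the stationary state covariance solving $P = F P F^\top + Q$ is $\frac{\sigma^2}{1-a^2} I$. The sketch below checks this with `scipy.linalg.solve_discrete_lyapunov` and an independent NumPy simulation, using the 10 Hz oscillator parameters from `o1`; it assumes the standard damped-rotation oscillator form rather than calling SOMATA, and uses new variable names so that `x` and `y` above are not overwritten.

```python
import numpy as np
from scipy.linalg import solve_discrete_lyapunov

# Parameters of the 10 Hz oscillator in o1 (assumed damped-rotation form)
a, sigma2, Fs, freq = 0.95, 0.2, 100, 10
w = 2 * np.pi * freq / Fs
F_osc = a * np.array([[np.cos(w), -np.sin(w)],
                      [np.sin(w),  np.cos(w)]])
Q_osc = sigma2 * np.eye(2)

# Stationary state covariance P solves P = F P F^T + Q
P_stat = solve_discrete_lyapunov(F_osc, Q_osc)
print(P_stat)                   # diagonal, with entries sigma2 / (1 - a**2)
print(sigma2 / (1 - a ** 2))

# Independent empirical check: run the state recursion for a long time
rng = np.random.default_rng(0)
T_sim = 100_000
x_sim = np.zeros((2, T_sim))
w_noise = rng.normal(scale=np.sqrt(sigma2), size=(2, T_sim))
for t in range(1, T_sim):
    x_sim[:, t] = F_osc @ x_sim[:, t - 1] + w_noise[:, t]
print(np.var(x_sim[:, 1000:], axis=1))  # discard the first samples as burn-in
```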

### 3.2 Iterative oscillator search algorithm (iOsc+)

In [None]:
# Initialize an iterative oscillator model object
io1 = IterOsc(y, o1.Fs, noise_start=None, osc_range=7)
# noise_start is the frequency (Hz) above which the spectrum is used to estimate the observation noise; default: (Nyquist - 20 Hz)
# osc_range is the maximum number of total oscillators; default: 7
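
The role of `noise_start` can be illustrated with a simple spectral argument: white observation noise of variance $R$ has a flat one-sided power spectral density of $2R/F_s$, so averaging the spectrum above a cutoff where oscillations contribute little gives an estimate of $R$. This is a generic sketch of the idea using pure simulated noise; it is not necessarily how `IterativeOscillatorModel` computes its estimate.

```python
import numpy as np
from scipy import signal

Fs = 100
R_true = 1.2
rng = np.random.default_rng(1)
y_noise = rng.normal(scale=np.sqrt(R_true), size=Fs * 60)  # 60 s of white noise

# One-sided Welch PSD with density scaling: flat at 2 * R / Fs for white noise
f, pxx = signal.welch(y_noise, fs=Fs, nperseg=256)

noise_start = Fs / 2 - 20        # mirrors the stated default: Nyquist - 20 Hz
band = (f > noise_start) & (f < Fs / 2)
R_hat = np.mean(pxx[band]) * Fs / 2
print(R_hat)
```

With real data, the estimate is only as good as the assumption that nothing but noise lives above the cutoff.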

Oscillator search model objects have convenient visualization methods implemented.

In [None]:
_ = io1.plot_mtm()  # plot multitaper spectrogram and mean spectrum
_ = io1.plot_trace()  # plot raw time trace

Now run iterations to search for oscillations present in the simulated data while visualizing every iteration.

In [None]:
io1.iterate(freq_res=1, plot_fit=True, verbose=True)  # this is the iOsc+ algorithm
# freq_res is the minimum frequency separation (Hz) from existing oscillators when adding a new oscillator
# plot_fit=True plots innovation spectrum and AR fitting during each iteration
# verbose=True prints parameters throughout the method

Examine the final oscillator model selected by the iOsc+ algorithm.

In [None]:
print(io1.get_knee_osc())

Plot the log-likelihoods and the selected model (which may not be the one with the highest likelihood).

In [None]:
_ = io1.plot_log_likelihoods()
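
The name `get_knee_osc()` suggests the selected model sits at the "knee" of the log-likelihood curve, where adding oscillators stops paying off. The exact criterion SOMATA uses is not shown in this lab; below is one common generic heuristic (the point farthest from the straight line joining the first and last points, after scaling both axes to $[0, 1]$), with a hypothetical helper name and made-up illustrative numbers, to build intuition.

```python
import numpy as np

def knee_index(values):
    """Hypothetical helper: index of the point farthest from the chord joining
    the first and last points of a curve, after scaling both axes to [0, 1]."""
    values = np.asarray(values, dtype=float)
    x = np.linspace(0.0, 1.0, len(values))
    y = (values - values[0]) / (values[-1] - values[0])
    # The normalized chord is the line y = x; distance to it is |y - x| / sqrt(2)
    return int(np.argmax(np.abs(y - x)))

# Made-up log-likelihoods for models with 1, 2, ..., 6 oscillators
logL_demo = [0.0, 10.0, 15.0, 16.0, 16.5, 16.8]
k = knee_index(logL_demo)
print('knee at', k + 1, 'oscillators')
```

A knee criterion trades likelihood for parsimony, which is why the selected model need not be the one with the highest likelihood.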

Plot the fitted oscillators in the frequency domain with both theoretical and empirical spectra.

Note that the empirical spectrum is based on smoothing estimates of the hidden states.

In [None]:
_ = io1.plot_fit_spectra()

Plot the time traces of estimated $\mathbf{x}_t$.

In [None]:
_ = io1.plot_fit_traces()

### 3.3 Decomposed oscillator search algorithm (dOsc)

dOsc has a constructor very similar to that of iOsc+. The `plot_mtm()` and `plot_trace()` methods work the same way, so we skip them here.

In [None]:
do1 = DecOsc(y, o1.Fs, noise_start=None, osc_range=7)
# noise_start is the frequency (Hz) above which the spectrum is used to estimate the observation noise; default: (Nyquist - 20 Hz)
# osc_range is the maximum number of total oscillators; default: 7

The method that runs the iterations to search for oscillations is also called `iterate()`, albeit with slightly different arguments.

In [None]:
do1.iterate(plot_fit=True)  # this is the dOsc algorithm
# plot_fit=True plots fitted theoretical spectra during each iteration

Also examine the final oscillator model selected by the dOsc algorithm.

In [None]:
print(do1.get_knee_osc())

The same sanity check methods are available for the dOsc algorithm.

In [None]:
_ = do1.plot_log_likelihoods()
_ = do1.plot_fit_spectra()
_ = do1.plot_fit_traces()

### 3.4 Diagnostic residual plots and statistical tests

We can plot the model fitting residuals (one-step prediction error), i.e., $y_t - \hat{y}_{t|t-1}$, and run diagnostic statistical tests to check for auto-correlations and normality.

The same `diagnose_residual()` method applies to both oscillator search methods. Below we demonstrate it for `DecomposedOscillatorModel`.

In [None]:
do1.diagnose_residual()
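
For intuition about what such residual checks test, here is a hand-rolled Ljung-Box test for autocorrelation plus SciPy's D'Agostino-Pearson normality test (`scipy.stats.normaltest`). Which tests `diagnose_residual()` actually runs is not specified here; this is a generic sketch with a hypothetical helper name, applied to toy sequences. Residuals from a well-fitted model should resemble white noise and should not lead to rejection by either test.

```python
import numpy as np
from scipy import stats

def ljung_box(resid, nlags=20):
    """Hypothetical helper: Ljung-Box Q statistic and p-value for residual
    autocorrelation up to lag nlags (null hypothesis: no autocorrelation)."""
    r = np.asarray(resid, dtype=float) - np.mean(resid)
    n = len(r)
    denom = np.sum(r ** 2)
    acf = np.array([np.sum(r[k:] * r[:-k]) / denom for k in range(1, nlags + 1)])
    Q = n * (n + 2) * np.sum(acf ** 2 / (n - np.arange(1, nlags + 1)))
    return Q, stats.chi2.sf(Q, df=nlags)

rng = np.random.default_rng(2)
white = rng.normal(size=2000)
# An AR(1) sequence mimics residuals from a model that missed some dynamics
ar1 = np.zeros(2000)
for t in range(1, 2000):
    ar1[t] = 0.8 * ar1[t - 1] + white[t]

print('white:', ljung_box(white), stats.normaltest(white).pvalue)
print('AR(1):', ljung_box(ar1))
```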

#### With that, we are done. THANK YOU EVERYONE!