In [None]:
import numpy as np
from matplotlib import pyplot as plt
from numpy import pi as pi
from somata.oscillator_search import IterativeOscillatorModel
from somata.pac.pac_model import fit_pac_regression, kmod, phimod, optimize_arp

def rotation_matrix(Fs, f):
    """
    2x2 rotation matrix that advances a 2-D oscillator state by one sample.

    :param Fs: Sampling frequency
    :param f: Oscillation frequency (same units as Fs)
    :return: numpy array [[cos w, -sin w], [sin w, cos w]] with w = 2*pi*f/Fs
    """
    omega = 2 * pi * f / Fs
    c, s = np.cos(omega), np.sin(omega)
    return np.array([[c, -s], [s, c]])

def window(y, window_length = 500, window_stride = 100, min_length = None):
    """
    Splits a sequence into (possibly overlapping) windows.

    Window ``j`` starts at ``j * window_stride`` and spans up to ``window_length``
    samples. A window is only started while at least ``min_length`` samples remain
    from its start. With the default ``min_length = window_length``, every window
    is full length and any shorter tail is dropped. A smaller ``min_length`` keeps
    truncated windows at the end.

    :param y: Sequence to split (anything supporting ``len`` and slicing, e.g. a numpy array or list)
    :param window_length: Number of samples per window, defaults to 500
    :param window_stride: Number of samples between consecutive window starts, defaults to 100
    :param min_length: Minimum length of signal remaining for there to be a new window,
        defaults to None (meaning ``window_length``)
    :raises ValueError: If ``window_length``, ``window_stride`` or ``min_length`` is less than 1
    :return: List of slices of ``y``; empty if ``y`` is shorter than ``min_length``
    """
    if min_length is None:
        min_length = window_length
    if window_length < 1 or window_stride < 1 or min_length < 1:
        raise ValueError("window_length, window_stride and min_length must all be >= 1")

    n = len(y)
    # Starts are limited so that at least `min_length` samples remain; ends are clipped to the signal end.
    return [y[start:min(start + window_length, n)] for start in range(0, n - min_length + 1, window_stride)]

# Generate data
def generate_pac_data(n = int(1e3), Fs = 100, f_slow = 0.5, a_slow = 0.9999, Q_slow = 1e-5, R_slow = 0.1, f_fast = 5, a_fast = 0.9999, Q_fast = 1e-5, R_fast = 0.1, k = 1, phi = 0, rng = None):
    """
    Generates a fast and slow signal with phase amplitude modulation

    Each oscillation is a 2-D state that is rotated by ``rotation_matrix`` and
    scaled by its damping factor every sample. The observed signal is the first
    state component plus Gaussian observation noise. The fast signal is
    additionally multiplied by ``1 + k * cos(phase_slow - phi)``. Here
    ``phase_slow = arctan2(x_slow[1], x_slow[0])``, the same phase definition
    used when analysing fitted oscillator states.

    :param n: Number of data points, defaults to int(1e3)
    :param Fs: Sampling frequency, defaults to 100
    :param f_slow: Frequency of slow oscillation, defaults to 0.5
    :param a_slow: Damping factor of slow oscillation, defaults to 0.9999
    :param Q_slow: State noise variance (per state component) of slow oscillation, defaults to 1e-5
    :param R_slow: Observation noise standard deviation of slow signal, defaults to 0.1
    :param f_fast: Frequency of fast oscillation, defaults to 5
    :param a_fast: Damping factor of fast oscillation, defaults to 0.9999
    :param Q_fast: State noise variance (per state component) of fast oscillation, defaults to 1e-5
    :param R_fast: Observation noise standard deviation of fast signal, defaults to 0.1
    :param k: Magnitude of phase-amplitude coupling, defaults to 1
    :param phi: Phase of slow oscillation at which amplitude of fast oscillation peaks, defaults to 0
    :param rng: numpy random Generator. Defaults to None, in which case a fresh
        ``np.random.default_rng(12345789)`` is created on every call, so repeated
        calls with default arguments return identical data.

    :return: Tuple of t (time), y_slow (Slow signal), y_fast (Fast signal)
    """
    # A default Generator in the signature would be created once and shared
    # (and advanced) across calls, breaking reproducibility.
    if rng is None:
        rng = np.random.default_rng(12345789)

    y_slow = np.empty(n)
    y_fast = np.empty(n)
    # Integer arange guarantees exactly n time points (float-step arange can be off by one).
    t = np.arange(n) / Fs

    F_slow = rotation_matrix(Fs, f_slow)
    F_fast = rotation_matrix(Fs, f_fast)

    x_slow = np.array([1., 0.])
    x_fast = np.array([1., 0.])

    for i in range(n):
        # Signed phase of the slow oscillation; well-defined for any state, including the origin.
        phase_slow = np.arctan2(x_slow[1], x_slow[0])

        y_slow[i] = x_slow[0] + rng.normal(0, R_slow)
        # Amplitude of the fast oscillation peaks when the slow phase equals phi.
        y_fast[i] = x_fast[0] * (1 + k * np.cos(phase_slow - phi)) + rng.normal(0, R_fast)

        x_slow = a_slow * F_slow @ x_slow + rng.multivariate_normal([0, 0], np.diag([Q_slow, Q_slow]))
        x_fast = a_fast * F_fast @ x_fast + rng.multivariate_normal([0, 0], np.diag([Q_fast, Q_fast]))

    return t, y_slow, y_fast



In [None]:
# Simulation settings: number of samples and sampling frequency.
n = 2_000
Fs = 100
t, y_slow, y_fast = generate_pac_data(n=n, Fs=Fs)

In [None]:
# Simulated slow signal and the phase-amplitude-coupled fast signal.
plt.plot(t, y_slow, label="slow")
plt.plot(t, y_fast, label="fast")
plt.xlabel("Time")
plt.ylabel("Signal")
plt.title("Simulated PAC data")
plt.legend()
plt.show()

In [None]:
# Fit iterative oscillator models to each simulated signal.
# The second argument is presumably the sampling frequency. It was hard-coded to 100,
# which is the value of Fs; pass Fs so the fit cannot drift from the simulation.
iosc_slow = IterativeOscillatorModel(y_slow, Fs)
iosc_slow.iterate()

iosc_fast = IterativeOscillatorModel(y_fast, Fs)
iosc_fast.iterate()

In [None]:
# Pick the model at the "knee" of each iterative search, then show the slow one.
fitted_slow, fitted_fast = iosc_slow.get_knee_osc(), iosc_fast.get_knee_osc()

print(fitted_slow)


In [None]:

# Show the model selected for the fast signal.
print(fitted_fast, end="\n")

In [None]:
# Smoothed state estimates ('x_t_n') for each fitted model; the 1-D observations
# are given a leading axis of length 1 before being passed in.
x_slow = fitted_slow.kalman_filt_smooth(y=y_slow[np.newaxis, :], return_dict=True)['x_t_n']
x_fast = fitted_fast.kalman_filt_smooth(y=y_fast[np.newaxis, :], return_dict=True)['x_t_n']

In [None]:
# Instantaneous amplitude and phase, both taken from the first 2-D oscillator state (rows 0 and 1).
# The original amplitude summed squares over ALL state rows, which disagrees with the phase
# whenever the knee model contains more than one oscillator.
# Column 0 is dropped (presumably the initial state -- TODO confirm).
# NOTE(review): with several oscillators, confirm rows 0-1 are the oscillation of interest.
amplitude_slow = np.hypot(x_slow[0, 1:], x_slow[1, 1:])
amplitude_fast = np.hypot(x_fast[0, 1:], x_fast[1, 1:])

phase_slow = np.arctan2(x_slow[1, 1:], x_slow[0, 1:])
phase_fast = np.arctan2(x_fast[1, 1:], x_fast[0, 1:])

In [None]:
# Estimated instantaneous amplitudes of the slow and fast oscillations.
plt.plot(t, amplitude_slow, label="slow amplitude")
plt.plot(t, amplitude_fast, label="fast amplitude")
plt.xlabel("Time")
plt.ylabel("Amplitude")
plt.legend()
plt.show()
         

In [None]:
# Estimated instantaneous phases (arctan2 output, radians in [-pi, pi]).
plt.plot(t, phase_slow, label="slow phase")
plt.plot(t, phase_fast, label="fast phase")
plt.xlabel("Time")
plt.ylabel("Phase (rad)")
plt.legend()
plt.show()
         

In [None]:
# Fit the PAC regression over the whole record: slow-oscillation phase vs fast-oscillation amplitude.
# (No windowing here; the windowed version is fitted in a later cell.)
beta = fit_pac_regression(phase_slow, amplitude_fast)

In [None]:
# Point estimate: average the regression output over axis 0
# (presumably one row per draw -- TODO confirm), then convert to modulation depth/phase.
beta_map = np.mean(beta, axis=0)
fitted_k = kmod(*beta_map[:3])
fitted_phi = phimod(*beta_map[1:3])
print(f"kmod = {fitted_k:.5f}, phimod = {fitted_phi:.5f}")

In [None]:
# Split phase and amplitude into matching windows (250 samples, stride 100) and fit the
# PAC regression separately in each window.
windowed_phase = window(phase_slow, 250, 100)
windowed_amplitude = window(amplitude_fast, 250, 100)
windowed_beta = [
    fit_pac_regression(phase_win, amp_win)
    for phase_win, amp_win in zip(windowed_phase, windowed_amplitude)
]

In [None]:
# Per-window point estimates (mean over axis 0, as for beta_map), stacked so that each row is one
# coefficient and each column one window. np.vstack replaces np.row_stack (deprecated in NumPy 2.0).
windowed_beta_map = np.vstack([np.mean(x, 0) for x in windowed_beta]).T
# Optional: de-mean each coefficient's time course before AR fitting.
# windowed_beta_map = windowed_beta_map - np.mean(windowed_beta_map, 1)[:,None]

In [None]:
# Time course of the first three windowed regression coefficients (one line per row).
for i, label in enumerate(["beta[0]", "beta[1]", "beta[2]"]):
    plt.plot(np.arange(windowed_beta_map.shape[-1]), windowed_beta_map[i, :], label=label)
plt.xlabel("Window index")
plt.ylabel("Coefficient value")
plt.legend()
plt.show()

In [None]:
from somata.pac.pac_model import autocovariances, block_toeplitz, ar_parameters, mvar_ssm
# Fit an autoregressive model to the windowed coefficient time courses
# (the `1` is presumably the AR order -- TODO confirm), then build the model object
# that the next cell filters/smooths.
A, Q, R = optimize_arp(
    windowed_beta_map,
    1,
)
model = mvar_ssm(
    windowed_beta_map,
    A,
    Q,
    R,
)

In [None]:
# Run the filter/smoother on the MVAR model of the windowed coefficients.
z = model.kalman_filt_smooth(return_dict=True)

# Filtered state estimates with column 0 dropped (same [:, 1:] slicing as earlier cells).
# NOTE(review): this uses the filtered 'x_t_t', while earlier cells use the smoothed
# 'x_t_n' -- confirm which one is intended.
x = z['x_t_t'][:, 1:]

plt.plot(x[0])
plt.plot(x[1])
plt.plot(x[2])
plt.show()
