# Fading Memory Systems - Correlation Analysis

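Correlation analysis of, and a playground for, fading memory systems (FMS): we simulate a randomly initialized FMS, inspect the auto- and cross-correlation of its input/output signals, and fit a simple CNN model to noisy measurements of its output.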


In [None]:
%reload_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as ssig
import torch
import torch.nn as nn

from src.models.simple_cnn import SimpleCNN
from src.utils.plotting import init_plot_style

# initialize the project's global plot style once, before any figure is created
# (see src.utils.plotting.init_plot_style for the concrete settings)
init_plot_style()


In [None]:
class FMS:
    """Randomly initialized, nonlinear fading memory system (FMS).

    At every time step, a small randomly initialized MLP maps the current and the past
    ``mem_size - 1`` input samples (plus one process-noise sample and, if ``autoregressive``,
    the past ``mem_size - 1`` output samples) to a scalar ``z``. ``z`` is accumulated into an
    exponentially decaying internal memory, and the output sample is
    ``|cos(memory + 10 * z)|``, which always lies in ``[0, 1]``.
    """

    def __init__(self, in_size: int, mem_size: int, autoregressive: bool = False, mem_decay: float = 0.9,
                 pnoise: float = 0.01, hidden_size: int = 32):
        """Create the system and randomly initialize its internal MLP.

        Parameters
        ----------
        in_size: int
            Number of channels of the input signal (must be >= 1).
        mem_size: int
            Memory lag of the system, i.e., the system will use current + the past mem_size - 1 input samples
            to compute the next output (must be >= 1).
        autoregressive: bool
            If true, uses the past mem_size - 1 output samples to compute the next output.
        mem_decay: float
            The system uses an internal, exponentially decaying memory with decay factor mem_decay.
        pnoise: float
            Standard deviation of the included process noise.
        hidden_size: int
            Width of the hidden layer of the internal MLP (default: 32).

        Raises
        ------
        ValueError
            If in_size, mem_size or hidden_size is smaller than 1.
        """
        if min(in_size, mem_size, hidden_size) < 1:
            raise ValueError(f'in_size, mem_size and hidden_size must be >= 1, '
                             f'got {in_size}, {mem_size} and {hidden_size}')

        self.in_size = in_size
        self.mem_size = mem_size
        self.autoregressive = autoregressive
        self.mem_decay = mem_decay
        self.process_noise = pnoise
        self.hidden_size = hidden_size

        # per-step feature vector: in_size * mem_size input samples + 1 process-noise sample,
        # plus mem_size - 1 past output samples in autoregressive mode
        num_features = self.in_size * self.mem_size + 1
        if autoregressive:
            num_features += self.mem_size - 1
        self.mlp = nn.Sequential(nn.Linear(num_features, self.hidden_size),
                                 nn.LeakyReLU(),
                                 nn.Linear(self.hidden_size, 1),
                                 )

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Simulate the system for the input signal ``x``.

        Parameters
        ----------
        x: torch.Tensor
            Input signal of shape ``(in_size, sequence_length)``.

        Returns
        -------
        torch.Tensor
            Output signal of shape ``(1, sequence_length)`` with values in ``[0, 1]``.

        Raises
        ------
        ValueError
            If ``x`` is not a 2-D tensor with ``in_size`` rows.
        """
        if not (torch.is_tensor(x) and x.dim() == 2 and x.shape[0] == self.in_size):
            raise ValueError(f'expected a tensor of shape ({self.in_size}, sequence_length), '
                             f'got {type(x).__name__} with shape {getattr(x, "shape", None)}')
        seq_length = x.shape[-1]

        # zero-pad mem_size - 1 samples on the left so every step sees a full input window
        x_padded = nn.functional.pad(x, (self.mem_size - 1, 0))
        assert x_padded.shape == (self.in_size, seq_length + self.mem_size - 1)

        # the output buffer carries the same left zero-padding so past outputs can be sliced
        # with the same window indices as past inputs
        output = torch.zeros(1, self.mem_size - 1 + seq_length)
        memory = torch.tensor(0.)
        with torch.no_grad():
            for n in range(seq_length):
                z = torch.cat(
                    (x_padded[:, n:n + self.mem_size].flatten(),  # current + past inputs
                     torch.randn(1) * self.process_noise),  # process noise
                    dim=-1)

                if self.autoregressive:
                    # append the past mem_size - 1 outputs
                    z = torch.cat((z, output[0, n:n + self.mem_size - 1]), dim=-1)

                z = self.mlp(z)
                memory = self.mem_decay * memory + z  # exponentially decaying memory
                output[0, n + self.mem_size - 1] = torch.cos(memory + 10. * z).abs()

        # drop the left padding
        return output[:, self.mem_size - 1:]



In [None]:

# single-channel input, 5-sample input window, autoregressive, slowly decaying memory (0.99)
system = FMS(in_size=1, mem_size=5, mem_decay=0.99, pnoise=0.01, autoregressive=True)


In [None]:

# draw a long test input and simulate the system response to it
n_samples = 10_000
test_signal = torch.rand(n_samples)  # uniform white noise on [0, 1)
# alternative test inputs:
# test_signal = torch.ones(n_samples)  # constant input
# test_signal = torch.cos(0.1 * torch.pi * torch.arange(n_samples))  # sinusoid
test_output = system(test_signal[None, :])[0]


In [None]:

# plot the input/output behavior (first 100 samples)
plt.close('all')
plt.figure(figsize=(12, 6))
plt.plot(test_signal[:100], label='Test Input, $x[n]$')
plt.plot(test_output[:100], label='Test Output, $y[n]$')
plt.xlabel('Time Index, $n$')
plt.ylabel('$x[n]$, $y[n]$')
plt.legend()
plt.tight_layout()

# compute correlation functions
max_lag = 25  # only lags |k| < max_lag are plotted
threshold = 0.1  # half-width of the gray band drawn around zero
window = torch.hann_window(n_samples)
# x = test_signal
# y = test_output

# remove the mean and taper both signals with a Hann window
x = (test_signal - test_signal.mean()) * window
y = (test_output - test_output.mean()) * window

# Dividing by the window's own autocorrelation compensates the lag-dependent attenuation
# introduced by the taper; the 1e-3 offset avoids dividing by ~0 at the largest lags.
window_acf = np.abs(ssig.correlate(window, window, mode='full')) + 1e-3
# window_acf = 1.

# autocorrelation, normalized to 1 at lag 0 (index n_samples - 1 in 'full' mode)
ycorr = ssig.correlate(y, y, mode='full') / window_acf
ycorr = ycorr / ycorr[n_samples - 1]
# cross-correlation, normalized by its peak magnitude (not by sqrt(Rxx(0) * Ryy(0))),
# so the threshold band is relative to the strongest lag rather than a correlation coefficient
xycorr = ssig.correlate(x, y, mode='full') / window_acf
xycorr = xycorr / np.abs(xycorr).max()

lags = ssig.correlation_lags(n_samples, n_samples, mode='full')
lag_mask = np.abs(lags) < max_lag  # computed once, reused by both plots
threshold_band = torch.ones(len(lags[lag_mask])) * threshold


def _plot_correlation(corr, ylabel):
    """Stem-plot ``corr`` over the lags ``|k| < max_lag`` in a new figure.

    Reads the notebook globals ``lags``, ``lag_mask``, ``threshold_band`` and ``system``.
    The gray band marks ``+/- threshold``; the red dashed lines mark lag 0 and lag
    ``-(system.mem_size - 1)``, the extent of the system's input window.
    """
    shown_lags = lags[lag_mask]
    plt.figure(figsize=(12, 6))
    plt.fill_between(shown_lags, -threshold_band, threshold_band, color='gray', alpha=0.5)
    plt.plot([0, 0], [-1, 1], 'r--')
    plt.plot([-system.mem_size + 1, -system.mem_size + 1], [-1, 1], 'r--')
    plt.stem(shown_lags, corr[lag_mask])
    plt.xlabel('Time Lag, $k$')
    plt.ylabel(ylabel)
    plt.tight_layout()


# plot the output autocorrelation and the input/output cross-correlation
_plot_correlation(ycorr, 'Autocorrelation')
_plot_correlation(xycorr, 'Cross-Correlation')
# to inspect the window correction itself: _plot_correlation(window_acf, 'Window ACF')

We can now generate a training set and a test set by drawing random input signals,
propagating them through the system, and corrupting the outputs with additive white Gaussian noise.

In [None]:
n_train = 500  # length of training signals
n_test = 200  # length of test signals
noise_std = 0.4  # standard deviation of the additive white Gaussian measurement noise
rng = np.random.default_rng(seed=0)

# Inputs and measurement noise are drawn from the seeded numpy generator, so they are
# reproducible. The system's MLP weights and its internal process noise still come from
# torch's global RNG; seed torch before creating `system` for a fully reproducible data set.
# .float() converts to float32 to match the dtype of the system's MLP.
x_train = torch.from_numpy(rng.uniform(-1.0, 1.0, size=(1, n_train))).float()  # uniform on [-1, 1)
noise = torch.from_numpy(rng.normal(0.0, noise_std, size=(1, n_train))).float()  # generate noise
y_train = system(x_train) + noise  # simulate measurements

x_test = torch.from_numpy(rng.uniform(-1.0, 1.0, size=(1, n_test))).float()  # uniform on [-1, 1)
noise = torch.from_numpy(rng.normal(0.0, noise_std, size=(1, n_test))).float()  # generate noise
y_test = system(x_test) + noise  # simulate measurements

In [None]:
# fit a SimpleCNN (30 kernels, memory depth 20) to the noisy training measurements
cnn_model = SimpleCNN(num_kernels=30, mem_depth=20)
train_loss_list = cnn_model.fit(x_train, y_train, max_epochs=400, learning_rate=1e-2)

# training curve: one MSE value per epoch, epochs counted from 1
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(train_loss_list) + 1), train_loss_list)
plt.xlabel('Epoch')
plt.ylabel('Training Set MSE')
plt.tight_layout()


In [None]:
# compute normalized train/test MSEs and plot prediction vs. measured output


def _evaluate(x, y, split):
    """Predict ``y`` from ``x`` with the notebook-global ``cnn_model``, report and plot the fit.

    The reported error is the MSE divided by the mean power of ``y`` (a normalized MSE),
    so a value of 1.0 corresponds to predicting all zeros.
    Returns the prediction and the normalized MSE (as a tensor).
    """
    with torch.no_grad():  # inference only; no autograd graph needed
        pred = cnn_model(x).detach().squeeze(0)
    nmse = ((pred - y) ** 2).mean() / (y ** 2).mean()
    print(f'{split} normalized MSE is {nmse.item()}.')
    plt.figure(figsize=(12, 6))
    plt.plot(y.squeeze(0), label=f'True {split} Output, $y[n]$')
    plt.plot(pred, label=r'Predicted Output, $\hat{y}[n]$')
    plt.xlabel('Time Index, $n$')
    plt.ylabel(r'$y[n]$, $\hat{y}[n]$')
    plt.legend()
    plt.tight_layout()
    return pred, nmse


# `prediction` ends up holding the test-set prediction, as before
prediction, train_mse = _evaluate(x_train, y_train, 'Training')
prediction, test_mse = _evaluate(x_test, y_test, 'Test')