# Cross Correlation
- Cross correlation measures **“how strongly two EEG channels look like the same waveform, allowing for a small time shift”**
- Below is a sample function for cross correlation

```python
from __future__ import annotations  # lets the `X | None` annotation below work on Python < 3.10

import numpy as np

def feat_xcor(x: np.ndarray, max_lag_s: float, fs: float, pairs: list[tuple[int,int]] | None = None) -> np.ndarray:
    """
    XCOR block:
    For selected channel pairs, compute the peak absolute normalized
    cross-correlation within +/- max_lag.

    x         : EEG window, shape (n_channels, n_samples)
    max_lag_s : largest lag to search, in seconds
    fs        : sampling rate, in Hz
    pairs     : (i, j) channel index pairs; if None, all pairs of the first 6 channels
    Returns one float32 feature per pair.
    """
    # Get shape
    n_ch, n = x.shape
    # Get max lag in samples
    max_lag = int(max_lag_s * fs)

    if pairs is None:
        # default: a small set of pairs to limit feature size (first 6 channels)
        m = min(n_ch, 6)
        pairs = [(i, j) for i in range(m) for j in range(i+1, m)]

    # Cross-correlation proper
    feats = []
    for i, j in pairs:
        xi = x[i] - x[i].mean()
        xj = x[j] - x[j].mean()

        denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-12)
        # full correlation via FFT is faster, but direct is fine for small windows
        corr = np.correlate(xi, xj, mode="full") / denom
        mid = len(corr) // 2
        lo = max(0, mid - max_lag)
        hi = min(len(corr), mid + max_lag + 1)
        peak = np.max(np.abs(corr[lo:hi]))
        feats.append(peak)

    return np.array(feats, dtype=np.float32)
```
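The loop comment notes that FFT-based correlation is faster than `np.correlate` for long windows. Below is a minimal sketch of that swap, assuming real-valued 1-D channels; `xcorr_full_fft` is a hypothetical helper built only on `numpy.fft`, not a NumPy function. It zero-pads to a power of two at least `len(xi) + len(xj) - 1` long so that the circular correlation computed through the FFT equals the linear one, then reorders the result into the same lag order as `np.correlate(..., mode="full")`.

```python
import numpy as np

def xcorr_full_fft(xi: np.ndarray, xj: np.ndarray) -> np.ndarray:
    """FFT-based equivalent of np.correlate(xi, xj, mode="full") for real 1-D inputs.

    Hypothetical helper (not part of NumPy). Output entries are ordered by lag:
    -(len(xj) - 1), ..., 0, ..., len(xi) - 1, exactly like np.correlate.
    """
    n_out = len(xi) + len(xj) - 1
    nfft = 1 << (n_out - 1).bit_length()   # next power of two >= n_out, so no circular wrap-around
    spec = np.fft.rfft(xi, nfft) * np.conj(np.fft.rfft(xj, nfft))
    r = np.fft.irfft(spec, nfft)           # r[k] holds lag +k, r[nfft - k] holds lag -k
    neg = r[nfft - (len(xj) - 1):]         # lags -(len(xj) - 1) .. -1
    pos = r[:len(xi)]                      # lags 0 .. len(xi) - 1
    return np.concatenate([neg, pos])

rng = np.random.default_rng(0)
a = rng.standard_normal(500)
b = rng.standard_normal(500)
print(np.allclose(xcorr_full_fft(a, b), np.correlate(a, b, mode="full")))  # True
```

Inside `feat_xcor`, `np.correlate(xi, xj, mode="full")` could be replaced by `xcorr_full_fft(xi, xj)`; `mid`, `lo` and `hi` stay the same because the output order matches. For short EEG windows the direct version may well be fast enough.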

- Step 1: Get the channel pairs
  - If `pairs` is `None`, the function takes the first `min(n_ch, 6)` channels and builds every unique `(i, j)` pair with `i < j`
  - With 6 channels this gives $\binom{6}{2} = 15$ pairs (listed 1-based below; the code uses 0-based indices, so `(1,2)` here is `(0, 1)` in the code):
$$(1,2), (1,3), (1,4), (1,5), (1,6), (2,3), (2,4), (2,5), (2,6), (3,4), (3,5), (3,6), (4,5), (4,6), (5,6)$$
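A quick check of the default pairing logic with 0-based indices, as in the code. `itertools.combinations` is used here only as a cross-check; it yields the same pairs in the same order. The count grows as $m(m-1)/2$, which is presumably why the default caps $m$ at 6.

```python
from itertools import combinations

m = 6  # number of channels to pair up (the default cap in feat_xcor)
pairs = [(i, j) for i in range(m) for j in range(i + 1, m)]  # same comprehension as in the function

print(len(pairs), m * (m - 1) // 2)              # 15 15
print(pairs[:5])                                 # [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
print(pairs == list(combinations(range(m), 2)))  # True
```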

- Step 2: For each pair, we first remove the DC offset or the mean:
```python
    xi = x[i] - x[i].mean()
    xj = x[j] - x[j].mean()
```
- This removes baseline (DC) offsets, so the correlation compares the shapes of the two waveforms rather than their mean levels.
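A small check of why the mean is removed, on synthetic data: a noisy sine and the same signal shifted up by a constant 50 (both are assumptions for illustration). `zero_lag_similarity` is a hypothetical helper that evaluates only lag 0, with and without mean removal.

```python
import numpy as np

rng = np.random.default_rng(0)
t = np.arange(256)
a = np.sin(2 * np.pi * t / 32) + 0.1 * rng.standard_normal(t.size)
b = a + 50.0  # identical waveform riding on a large baseline offset

def zero_lag_similarity(u, v, demean=True):
    """Zero-lag normalized correlation, optionally removing each signal's mean first."""
    if demean:
        u = u - u.mean()
        v = v - v.mean()
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-12))

with_demean = zero_lag_similarity(a, b)                   # ~1.0: same shape is recognized
without_demean = zero_lag_similarity(a, b, demean=False)  # close to 0: the offset dominates the norm of b
print(round(with_demean, 4), round(without_demean, 4))
```

Without mean removal the constant offset inflates the norm of `b`, so two identical waveforms look almost unrelated.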

- Step 3: Normalize by energy:
```python
    denom = (np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-12)
```
  - Note that `np.linalg.norm(xi)` is the L2 norm: $\|x\|_2 = \sqrt{\sum_t x[t]^2}$
  - The product of the two norms equals $\sqrt{E_i E_j}$, the square root of the product of the signal energies $E = \sum_t x[t]^2$
  - Dividing the correlation by this product turns it into a normalized similarity: by the Cauchy-Schwarz inequality, the value at every lag lies in $[-1, 1]$
  - The `+1e-12` guards against division by zero, e.g. when a channel is flat: after mean removal its norm is exactly 0.
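A sketch of what the normalization buys at lag 0, using a synthetic 5-cycle sine (an assumption). `zero_lag_r` is a hypothetical helper that applies Steps 2 and 3 to the zero-lag term only.

```python
import numpy as np

def zero_lag_r(xi, xj):
    """Mean-removed, energy-normalized correlation at lag 0 (Steps 2-3 applied to one lag)."""
    xi = xi - xi.mean()
    xj = xj - xj.mean()
    denom = np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-12
    return float(np.dot(xi, xj) / denom)

t = np.linspace(0, 1, 200, endpoint=False)
s = np.sin(2 * np.pi * 5 * t)

print(round(zero_lag_r(s, 10 * s), 6))       # 1.0  -> amplitude scaling cancels out
print(round(zero_lag_r(s, -s), 6))           # -1.0 -> an inverted copy hits the lower bound
print(zero_lag_r(s, np.full_like(s, 3.0)))   # 0.0  -> flat channel: 0 / 1e-12 instead of 0 / 0
```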

- Step 4: Compute cross correlation across all lags ("full")
```python
corr = np.correlate(xi, xj, mode="full") / denom
```
  - Here, correlation is computed for all possible integer shifts
  - If each signal has length $N$ then the correlation output will have length $2N-1$
  - The middle index corresponds to `lag = 0` or no shift
  - Entries to the left of the middle correspond to negative lags, entries to the right to positive lags
  - Mathematically, with $x_i$ and $x_j$ the mean-removed signals, this is:
  $$r_{ij}[\ell] = \frac{\sum_t x_i[t+\ell]\, x_j[t]}{\|x_i\|_2 \, \|x_j\|_2}$$
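A quick way to see the sign convention of `np.correlate(xi, xj, mode="full")`, assuming synthetic white-noise data: `xi` is built as a copy of `xj` delayed by `d = 3` samples, and the peak then lands at lag $+3$. In other words, a peak at a positive lag means `xi` lags `xj`.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 256, 3                                 # assumed window length and delay (samples)
xj = rng.standard_normal(n)
xi = np.concatenate([np.zeros(d), xj[:-d]])   # xi[t] = xj[t - d]: xi is xj delayed by d samples
corr = np.correlate(xi, xj, mode="full")
lags = np.arange(-(n - 1), n)                 # lag of each entry; lags[n - 1] == 0
peak_lag = int(lags[np.argmax(corr)])
print(peak_lag)                               # 3: the peak sits at +d because xi lags xj
```

`feat_xcor` discards this sign by taking the maximum of the absolute value, so the feature is the same whichever channel leads.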

- Step 5: Keep only the lags between `-max_lag` and `+max_lag`
```python
    mid = len(corr) // 2
    lo = max(0, mid - max_lag)
    hi = min(len(corr), mid + max_lag + 1)
```
  - `mid` is the index for lag 0 (for two signals of length $N$, `mid = N - 1`)
  - The slice `[lo:hi]` keeps lags $\ell \in [-L_{\max}, +L_{\max}]$, where $L_{\max}$ is `max_lag = int(max_lag_s * fs)` samples
  - For EEG this restriction matters because we are only interested in near-synchronous (same time) or slightly delayed coupling between channels, not in matches found at long delays.
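A sketch of the window arithmetic with assumed numbers (256 Hz sampling, a 50 ms maximum lag and 512-sample windows; none of these values come from the original code). The `lags` array labels each entry of the `"full"` output with its lag, so the slice can be checked directly.

```python
import numpy as np

fs = 256.0        # assumed sampling rate (Hz)
max_lag_s = 0.05  # assumed maximum lag (50 ms)
n = 512           # assumed window length (samples)

max_lag = int(max_lag_s * fs)         # int(12.8) -> 12 samples
lags = np.arange(-(n - 1), n)         # lag of each entry of the "full" output
mid = len(lags) // 2                  # index of lag 0
lo = max(0, mid - max_lag)
hi = min(len(lags), mid + max_lag + 1)
kept = lags[lo:hi]
print(max_lag, kept[0], kept[-1], len(kept))  # 12 -12 12 25
```

`int()` truncates toward zero, so 0.05 s at 256 Hz becomes 12 samples rather than 12.8; `round()` would be the alternative if the nearest integer lag is preferred.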

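- Step 6: Take the peak magnitude (absolute value) within the lag window. Because of the absolute value, a strongly anti-correlated pair (one channel roughly an inverted copy of the other) scores as high as a strongly positively correlated pair: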
```python
peak = np.max(np.abs(corr[lo:hi]))
feats.append(peak)
```
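A sketch, on synthetic white noise (an assumption), of what the absolute value changes: when a channel is paired with its own inverted copy, the zero-lag value is close to $-1$, so without `np.abs` the peak would report only a small value and the coupling would be missed.

```python
import numpy as np

rng = np.random.default_rng(2)
xi = rng.standard_normal(512)
xj = -xi                                    # phase-inverted copy
xi, xj = xi - xi.mean(), xj - xj.mean()     # Step 2
denom = np.linalg.norm(xi) * np.linalg.norm(xj) + 1e-12
corr = np.correlate(xi, xj, mode="full") / denom
mid, max_lag = len(corr) // 2, 10           # assumed max_lag of 10 samples
window = corr[mid - max_lag: mid + max_lag + 1]

print(round(float(np.max(np.abs(window))), 3))  # 1.0: with abs, the inverted pair is fully coupled
print(float(np.max(window)) < 0.5)              # True: without abs, the peak stays small
```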

## What does `np.correlate(x, y, mode="full")` do?
- NumPy defines it (for real inputs; in general the second argument is complex-conjugated) as:
$$\left( \textrm{correlate} (x,y) \right)[k] = \sum_n x[n+k]\, y[n] $$
- And samples outside the valid range are treated as zero
- Notes:
  - This is correlation and NOT convolution
  - No time-reversal happens
  - One signal is slid across the other and a dot product is computed for every shift
- What does `mode="full"` mean?
  - Suppose `x` has length $N$ and `y` has length $M$
  - Then `len(np.correlate(x, y, mode="full")) == N + M - 1`
  - The output contains all possible shifts from $k = -(M-1)$ up to $N-1$
  - In other words, "compute correlation at every possible overlap including partial overlaps"
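  - At $k = -(M-1)$ only the last sample of `y` overlaps the first sample of `x`; as $k$ increases, `y` slides to the right along `x` until, at $k = N-1$, only the first sample of `y` overlaps the last sample of `x`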
- How does one interpret lag here?
  - Assume `x` and `y` have the same length $N$
  - Then the output length is $2N-1$
  - The center index, `mid = N - 1`, corresponds to lag 0
  - Refer to the table below for context.
  - In short, the output is a sequence of $2N-1$ correlation values ordered by lag: the center is lag 0, and the entry $k$ positions to the right (left) of the center is the correlation at lag $+k$ ($-k$).

| Index           | Lag  | What is compared                                         |
| --------------- | ---- | -------------------------------------------------------- |
| `corr[mid]`     | 0    | `x` and `y` aligned                                      |
| `corr[mid + k]` | $+k$ | `y` slid **right** by `k` (`y[n]` paired with `x[n+k]`)  |
| `corr[mid - k]` | $-k$ | `y` slid **left** by `k` (`y[n]` paired with `x[n-k]`)   |
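To check the definition and the `mode="full"` length rule for signals of different lengths, here is a direct loop over $c[k] = \sum_n x[n+k]\,y[n]$ compared against `np.correlate`, plus the standard identity that correlation equals convolution with the second signal reversed. `correlate_full_loop` is a hypothetical helper written only for illustration.

```python
import numpy as np

def correlate_full_loop(x, y):
    """Direct evaluation of c[k] = sum_n x[n + k] * y[n] for k = -(M-1) .. N-1 (real inputs)."""
    n_x, n_y = len(x), len(y)
    lags = np.arange(-(n_y - 1), n_x)
    c = np.zeros(len(lags))
    for idx, k in enumerate(lags):
        for n in range(n_y):
            if 0 <= n + k < n_x:  # samples outside the valid range count as zero
                c[idx] += x[n + k] * y[n]
    return lags, c

x = np.array([1.0, 2.0, 3.0, 4.0])  # N = 4
y = np.array([1.0, -1.0])           # M = 2
lags, c = correlate_full_loop(x, y)

print(len(c), c.tolist())                                    # 5 [-1.0, -1.0, -1.0, -1.0, 4.0]
print(np.allclose(c, np.correlate(x, y, mode="full")))       # True
print(np.allclose(c, np.convolve(x, y[::-1], mode="full")))  # True: no flip in correlation, so reverse y for convolve
```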

- Note that `np.correlate` does not normalize its output, which is why the function divides it by the product of the two L2 norms: `corr = np.correlate(xi, xj, mode="full") / denom`

- Here is a simple example. Consider:

```python
    x = [1, 2, 3]
    y = [4, 5, 6]
```
- Note that $N=3$ and we expect the output length to be $2N-1 = 2(3)-1 = 5$
- Lag 0 is the middle entry (index 2), lag $-2$ the leftmost (index 0) and lag $+2$ the rightmost (index 4).
- Visually, this is how each lag is computed:
- At lag = -2 overlap $\rightarrow$ $1 \times 6 = 6$
```python
    x:  1   2   3
    y:  6
```
- At lag = -1 overlap $\rightarrow$ $1 \times 5  + 2 \times 6 = 17$
```python
    x:  1   2   3
    y:  5   6
```
- At lag = 0 overlap $\rightarrow$ $1 \times 4  + 2 \times 5 + 3 \times 6 = 32$
```python
    x:  1   2   3
    y:  4   5   6
```
- At lag = +1 overlap $\rightarrow$ $2 \times 4  + 3 \times 5 = 23$
```python
    x:  1   2   3
    y:      4   5
```
- At lag = +2 overlap $\rightarrow$ $3 \times 4 = 12$
```python
    x:  1   2   3
    y:          4
```
- Therefore the `np.correlate(x,y,mode="full")` returns: `[6,17,32,23,12]`
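The worked example can be reproduced in a few lines; the `lags` array is built by hand from the rule that `mode="full"` covers $k = -(M-1), \dots, N-1$. The last line also shows that swapping the arguments mirrors the lag axis.

```python
import numpy as np

x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
corr = np.correlate(x, y, mode="full")
lags = np.arange(-(len(y) - 1), len(x))  # lag of each output entry: -2 .. +2

for lag, value in zip(lags, corr):
    print(f"lag {int(lag):+d}: {int(value)}")
# lag -2: 6
# lag -1: 17
# lag +0: 32
# lag +1: 23
# lag +2: 12

print(np.correlate(y, x, mode="full").tolist())  # [12, 23, 32, 17, 6]: swapping x and y reverses the output
```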