# Inference of causal graphs from data

Author: Marcell Stippinger

Date: 2025-10-10

## Contents

* Generate and fit VAR and VMA models
* Analyze Granger causality

In [None]:
## Install neurophysiological data library
!pip install "mne[hdf5]"
!pip install edfio

## Imports and plotting functions

In [None]:
import numpy as np
import pandas as pd
from scipy import stats, signal
from sklearn.utils import check_random_state
from typing import NamedTuple, Optional, Tuple, Any
#stationarity testing
from statsmodels.tsa.stattools import adfuller, kpss
#autocorrelation and partial autocorrelation
from statsmodels.tsa.stattools import acf, pacf
#VAR and VMA models
from statsmodels.tsa.vector_ar.var_model import VAR
from statsmodels.tsa.api import VARMAX
# Granger causality test
from statsmodels.tsa.stattools import grangercausalitytests
# Visualization
import matplotlib.pyplot as plt

In [None]:
def plot_ts(X, E=None, fs=1.0):
    """Plot every column of a time series in its own panel.

    The panels are stacked vertically and share the time axis. A dashed
    horizontal line marks zero in each panel.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features) or (n_samples,)
        Time series data to plot. Anything ``np.asarray`` accepts (including
        a DataFrame) is allowed; a 1-D input is treated as a single feature.
    E : array-like, same shape as ``X``, optional
        Noise components to overlay on the time series as markers.
    fs : float, optional
        Sampling frequency of the time series; the time axis is
        ``sample index / fs``.
    """
    # Convert up front so that positional column indexing below works for
    # DataFrames and lists as well as for ndarrays.
    X = np.asarray(X)
    if X.ndim == 1:
        X = X[:, np.newaxis]
    if E is not None:
        E = np.asarray(E)
        if E.ndim == 1:
            E = E[:, np.newaxis]

    n_samples, n_features = X.shape
    t = np.arange(n_samples) / fs
    # squeeze=False always returns a 2-D array of axes, so the single-feature
    # case needs no special handling.
    fig, axes = plt.subplots(n_features, 1, figsize=(6, 4), sharex=True, squeeze=False)
    axes = axes[:, 0]
    for i, ax in enumerate(axes):
        ax.axhline(0, color='gray', linestyle='--', linewidth=0.5)
        ax.plot(t, X[:, i])
        if E is not None:
            ax.plot(t, E[:, i], linestyle='None', marker='o', markersize=3, alpha=0.7, label='Noise')
            # Without a legend the 'Noise' label above is never displayed.
            ax.legend(loc='upper right')
        ax.set_title(f'Time Series {i+1}')
        ax.set_ylabel('Value')
    axes[-1].set_xlabel('Time')
    plt.tight_layout()
    plt.show()

## Explore some real-world data sets as well


## About stationarity tests

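* ADF: the null hypothesis is that the series has a unit root (is non-stationary); we reject it, i.e. conclude stationarity, if the $p$-value < $\alpha$
* KPSS: the null hypothesis is that the series is stationary; we reject it, i.e. conclude non-stationarity, if the $p$-value < $\alpha$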

In [None]:
help(adfuller)
# help(kpss)

In [None]:
class ADFullerResult(NamedTuple):
    """Named wrapper for the tuple returned by ``adfuller``.

    The field order matches positional unpacking, as in
    ``ADFullerResult(*adfuller(series))``. ``resstore`` defaults to ``None``
    so that a six-element result can be wrapped as well.
    """
    # test statistic
    adf: float
    # p-value of the test
    pvalue: float
    # number of lags used
    usedlag: int
    # number of observations used
    nobs: int
    # critical values of the test statistic
    critical_values: dict
    # best information criterion value
    icbest: float
    # optional seventh element; absent in the default six-element result
    resstore: Any = None

#adf_result = ADFullerResult(*adfuller(ts_var[:, 0], maxlag=5))
#adf_result

In [None]:
class KPSSResult(NamedTuple):
    """Named wrapper for the tuple returned by ``kpss``.

    The field order matches positional unpacking, as in
    ``KPSSResult(*kpss(series, nlags='auto'))``. ``resstore`` defaults to
    ``None`` so that a four-element result can be wrapped as well.
    """
    # test statistic
    kpss_stat: float
    # p-value of the test
    p_value: float
    # number of lags used
    lags: int
    # critical values of the test statistic
    crit: dict
    # optional fifth element; absent in the default four-element result
    resstore: Any = None

#kpss_result = KPSSResult(*kpss(ts_var[:, 1], nlags='auto'))
#kpss_result

## Autocorrelograms

In [None]:
def plot_autocorrelograms(x: np.ndarray, lags: int = 40):
    """Plot ACF and PACF of a time series as two stacked stem plots.

    Parameters
    ----------
    x : array-like, shape (n_samples,)
        Time series data.
    lags : int
        Number of lags to include in the plots. For short series the value is
        reduced to ``n_samples // 2 - 1``, because the partial autocorrelation
        can only be estimated for lags up to about half the sample size.
    """
    x = np.asarray(x)
    # Clamp instead of failing: a too-large `lags` would make pacf raise and
    # acf return fewer values than requested.
    lags = min(lags, max(x.shape[0] // 2 - 1, 0))

    acf_vals = acf(x, nlags=lags)
    pacf_vals = pacf(x, nlags=lags)

    fig, axes = plt.subplots(2, 1, figsize=(6, 4))

    panels = (
        (axes[0], acf_vals, 'Autocorrelation Function (ACF)', 'ACF'),
        (axes[1], pacf_vals, 'Partial Autocorrelation Function (PACF)', 'PACF'),
    )
    for ax, values, title, ylabel in panels:
        # Derive the x positions from the values actually returned so the two
        # arrays passed to stem always have the same length.
        ax.stem(np.arange(len(values)), values)
        ax.set_title(title)
        ax.set_xlabel('Lags')
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

## Granger causality test

We can follow, for example

Ding, M., Chen, Y., & Bressler, S. L. (2006). Granger Causality: Basic Theory and Application to Neuroscience. February. https://doi.org/10.1002/9783527609970.ch17

In [None]:
# Y -> X
#gr_yx = grangercausalitytests(coupled_ts, maxlag=4)

In [None]:
# X -> Y
#coupled_ts_rev = np.stack((coupled_ts[:, 1], coupled_ts[:, 0]), axis=1)
#gr_xy = grangercausalitytests(coupled_ts[:, ::-1], maxlag=4)
#gr_xy = grangercausalitytests(coupled_ts_rev, maxlag=4)

Explain the results

- which tests are significant
- for what lag

# Data analysis example

### Air Temperature vs Solar Irradiation

In [None]:
!wget https://webdav.tuebingen.mpg.de/cause-effect/pair0077.txt
!wget https://webdav.tuebingen.mpg.de/cause-effect/pair0077_des.txt
!wget https://webdav.tuebingen.mpg.de/cause-effect/pair0077.pdf

In [None]:
ts = pd.read_csv('data/pair0077.txt', sep=' ', names=['temp', 'solar'], index_col=False)
ts.head()

In [None]:
plot_ts(ts.values)

In [None]:
ts_diff = np.diff(ts, axis=0)
plot_ts(ts_diff)

In [None]:
plot_autocorrelograms(ts.values[:, 0])
plot_autocorrelograms(ts.values[:, 1])

In [None]:
plot_autocorrelograms(ts_diff[:, 0])
plot_autocorrelograms(ts_diff[:, 1])

In [None]:
MAXLAG = 4


def shifted_pair(effect, cause, delay):
    """Stack ``effect`` and ``cause`` as two columns, shifted against each other.

    For ``delay > 0`` the effect column is taken ``delay`` samples later than
    the cause column; for ``delay < 0`` the cause column is taken ``-delay``
    samples later than the effect column. Both columns are trimmed to the
    same length. The effect is placed first because ``grangercausalitytests``
    tests whether the second column Granger-causes the first.
    """
    if delay > 0:
        return np.column_stack((effect[delay:], cause[:-delay]))
    if delay < 0:
        return np.column_stack((effect[:delay], cause[-delay:]))
    return np.column_stack((effect, cause))


temp_diff, solar_diff = ts_diff[:, 0], ts_diff[:, 1]
# Keep every result, keyed by (delay, direction), instead of overwriting them.
granger_temp_solar = {}
for delay in range(0, 4):
    # Delay 0 has no negative counterpart; otherwise scan +delay, then -delay.
    for signed_delay in ((delay, -delay) if delay else (delay,)):
        print('\n\ndelay', signed_delay, 'temp->solar')
        granger_temp_solar[signed_delay, 'temp->solar'] = grangercausalitytests(
            shifted_pair(solar_diff, temp_diff, signed_delay), maxlag=MAXLAG)
        print('\n\ndelay', signed_delay, 'solar->temp')
        granger_temp_solar[signed_delay, 'solar->temp'] = grangercausalitytests(
            shifted_pair(temp_diff, solar_diff, signed_delay), maxlag=MAXLAG)


### Epileptic EEG Dataset @ Mendeley Data

This dataset includes the EEG of 6 epileptic patients recorded at the Epilepsy Monitoring Unit of the American University of Beirut Medical Center between January 2014 and July 2015. The data represent measurements from 21 scalp electrodes, following the 10-20 electrode system, sampled at 500 Hz. All channels have been bandpass filtered between 1/1.6 Hz and 70 Hz, and the 50 Hz mains (electrical utility) frequency has been filtered out. Some channels have been omitted from specific recordings due to artifact constraints.

* By Wassim Nasreddine
* Published: 16 March 2021| Version 1 | DOI: 10.17632/5pc2j46cbc.1
* Find here: https://data.mendeley.com/datasets/5pc2j46cbc/1

In [None]:
'''
{patient : [(hour, minute, second, duration) ... ]
'''
# One table per patient (the number in the variable name). Each entry is a
# seizure given as (hour, minute, second, duration). The integer keys are used
# as record numbers in this notebook: seizures_11[1] is paired with
# p11_Record1.edf.
seizures_15 = {
    1: [(17, 18, 8, 50)],
    2: [(22, 49, 24, 47)],
    3: [(2, 57, 4, 13)],
    4: [
        (5, 3, 26, 56),
        (6, 23, 29, 20),
    ],
}

seizures_14 = {
    1: [
        (14, 32, 2, 28),
        (15, 34, 32, 134),
    ],
    2: [
        (16, 20, 58, 32),  # no need for the replicates
        (17, 50, 56, 10),
    ],
    3: [
        (20, 20, 46, 31),
        (21, 2, 4, 26),
        (21, 27, 49, 40),
        (21, 50, 24, 40),
    ],
}

seizures_13 = {
    1: [(2, 31, 9, 52)],
    2: [
        (3, 33, 4, 25),
        (4, 38, 59, 16),
    ],
    3: [(6, 45, 51, 18)],
    4: [
        (10, 51, 41, 30),
        (12, 18, 22, 24),
    ],
}

seizures_12 = {
    1: [
        (1, 56, 14, 76),
        (2, 20, 12, 104),
    ],
    2: [(5, 50, 55, 118)],
    3: [
        (6, 40, 30, 82),  # no need for the replicates
        (7, 21, 36, 164),
        (8, 37, 48, 113),
    ],
}

seizures_11 = {
    1: [(15, 8, 55, 64)],
    2: [(18, 11, 38, 51)],
    3: [
        (19, 18, 38, 91),
        (20, 6, 38, 83),
        (20, 53, 22, 76),
        (21, 27, 24, 73),
    ],
    4: [(23, 55, 35, 1358)],
}

seizures_10 = {
    1: [(7, 36, 38, 445)],
    2: [(6, 29, 14, 305)],
}


In [None]:
# Download the epileptic dataset
import requests
import pathlib

# Unlike `!mkdir data`, this does not complain when the folder already exists.
pathlib.Path("data").mkdir(exist_ok=True)

user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
# Local file name -> download URL
epileptic = {
    'x_test.npz': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/93b81166-0e48-4dc0-ac20-b7167f7606c5/file_downloaded',
    'x_train.npz': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/169dca1c-4992-43d3-9c94-030de59c2524/file_downloaded',
    'y_test.npz': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/adf1c2fd-81ef-4f87-86cc-56d75bba8c31/file_downloaded',
    'y_train.npz': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/62accb90-a1b2-4b50-bde5-fbe6096f165f/file_downloaded',

    'p10_Record1.edf': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/6a2d23f5-9129-4970-9d1c-05f6012f9cd2/file_downloaded',
    'p11_Record1.edf': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/d7189a9f-d737-4924-bf9b-7b78c624a6f7/file_downloaded',
    'p12_Record1.edf': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/3e4f0e2d-2ad4-4af2-83a7-fc2e34bd933b/file_downloaded',
    'p12_Record2.edf': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/4439823c-3fe0-4369-a50a-fa077167237d/file_downloaded',
    'p15_Record4.edf': 'https://data.mendeley.com/public-files/datasets/5pc2j46cbc/files/32b8a29a-04d7-4505-ae14-f0a42e293b93/file_downloaded',
    # There are more files, check on the webpage. These are enough for the examples.
}

# Browser-like headers sent with every request of the session.
headers = {
    "User-Agent": user_agent,
    "Referer": "https://data.mendeley.com/",
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
}
with requests.Session() as sess:
    sess.headers.update(headers)
    for key, url in epileptic.items():
        outpath = f"data/{key}"
        if pathlib.Path(outpath).exists():
            print("File", outpath, "already exists, skipping download.")
            continue
        print("Downloading", key, "->", outpath)
        # Stream into a temporary name and rename only after the last chunk, so
        # an interrupted download never leaves a truncated file that the
        # "already exists" check above would accept on the next run.
        partial = pathlib.Path(outpath + ".part")
        try:
            # The context manager releases the streamed connection.
            with sess.get(url, stream=True, timeout=30) as r:
                if r.status_code == 200:
                    with open(partial, "wb") as f:
                        for chunk in r.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
                    partial.replace(outpath)
                    print("Saved", outpath)
                else:
                    print("Failed", key, r.status_code, r.reason)
                    print(" Response headers:", dict(r.headers))
        except requests.exceptions.RequestException as e:
            print("Error downloading", key, e)
            partial.unlink(missing_ok=True)


In [None]:
file = "data/p11_Record1.edf"
timing = seizures_11[1]

In [None]:
# import mne
# data = mne.io.read_raw_edf(file)  # this may take several minutes depending on the file size
# info = data.info
# print(info)
# channels = data.ch_names
# print(channels)

In [None]:
import edfio
edf = edfio.read_edf(file)

In [None]:
Fs = edf.signals[0].sampling_frequency
print("Sampling frequency:", Fs)
data = {s.label: s.data for s in edf.signals}

In [None]:
import datetime as dt
edf.startdate, edf.starttime, edf.duration

In [None]:
# The module is imported as `dt`; the bare name `datetime` is not defined.
delta = dt.timedelta(seconds=edf.duration)
start = dt.datetime.combine(edf.startdate, edf.starttime)
print(start + delta)  # end of the recording
# Seizure times are given as time of day only. Taking the difference modulo one
# day keeps the offset non-negative when the seizure falls after midnight,
# i.e. on the day after the recording started.
one_day = dt.timedelta(days=1)
offsets = [(dt.datetime.combine(edf.startdate, dt.time(*tmp[0:3])) - start) % one_day for tmp in timing]
durations = [dt.timedelta(seconds=tmp[3]) for tmp in timing]
print("Offsets:", offsets)
print("Durations:", durations)

In [None]:
data

In [None]:
example = np.column_stack((data['EEG Fp1-Ref'], data['EEG Cz-Ref']))
plot_ts(example, fs=Fs)

In [None]:
# Three one-minute windows: the second minute of the recording, the minute
# starting at the first listed seizure onset, and the minute after that.
onset_sample = Fs*offsets[0].total_seconds()
segment_bounds = (
    (int(Fs*60), int(Fs*120)),
    (int(onset_sample), int(onset_sample + Fs*60)),
    (int(onset_sample + Fs*60), int(onset_sample + Fs*120)),
)
for first, last in segment_bounds:
    plot_ts(example[first:last], fs=Fs)

In [None]:
def plot_spectrogram(data, fs, marks=()):
    """Plot a spectrogram in dB with the total power per time bin below it.

    Parameters
    ----------
    data : array-like
        Signal passed to ``scipy.signal.spectrogram``.
    fs : float
        Sampling frequency of ``data``.
    marks : iterable of (start, duration), optional
        Intervals to shade in red on the power panel, in the same unit as the
        spectrogram time axis (before its conversion to minutes).
    """
    f, t, Sxx = signal.spectrogram(data, fs)
    print("Spectrogram shape:", Sxx.shape)
    fig, ax = plt.subplots(2, 1, figsize=(12, 5), height_ratios=[3, 1], sharex=True)
    # Floor the power at the smallest positive normal float: log10(0) would
    # give -inf, which propagates into the percentile-based colour limits.
    scaled = 10 * np.log10(np.maximum(Sxx, np.finfo(Sxx.dtype).tiny))
    # Clip the colour scale to the 1st-99th percentile so outliers do not
    # dominate it.
    lim = np.percentile(scaled, [1., 99.])
    mesh = ax[0].pcolormesh(t/60, f, scaled, vmin=lim[0], vmax=lim[1], shading='gouraud')
    # Attach the colorbar to both panels so their shared time axes stay aligned.
    plt.colorbar(mesh, ax=ax, label='Intensity [dB]')
    ax[0].set_ylabel('Frequency [Hz]')
    ax[0].set_title('Spectrogram')
    ax[1].plot(t/60, np.sum(Sxx, axis=0), label='Power')
    ax[1].set_ylabel('Power')
    ax[1].set_xlabel('Time [min]')
    for mark in marks:
        ax[1].axvspan(mark[0]/60, (mark[0]+mark[1])/60, color='red', alpha=0.3)
    plt.show()

marks = [(offset.total_seconds(), duration.total_seconds()) for offset, duration in zip(offsets, durations)]
plot_spectrogram(data['EEG O1-Ref'], Fs, marks)

**Documentation:**

-	1: for Complex Partial Seizures 
(3034 matrices of size 19x500 corresponding to 3034 seconds of complex partial seizures)
-	2: for Electrographic Seizures
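(705 matrices of size 19x500 corresponding to 750 seconds of electrographic seizures; note that these two counts disagree, and 750 is the value consistent with the 3895-second total given below)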
-	3: for Video-detected Seizures with no visual change over EEG
(111 matrices of size 19x500 corresponding to 111 seconds of Video-detected Seizures with no visual change over EEG)
-	0: for Normal data
(3895 matrices of size 19x500 corresponding to 3895 seconds of normal data, 3895 is the total duration of all available seizures to create the balance between normal and lesional data).   


In [None]:
seizure_type = {0: 'Normal',
                1: 'ComplexPartial',
                2: 'Electrographic',
                3: 'VideoDetected'}

In [None]:

labels = ['EEGFp2Ref', 'EEGFp1Ref', 'EEGF8Ref', 'EEGF4Ref', 'EEGFzRef', 'EEGF3Ref', 'EEGF7Ref', 'EEGA2Ref', 'EEGT4Ref', 'EEGC4Ref', #'EEGCzRef',
          'EEGC3Ref', 'EEGT3Ref', 'EEGA1Ref', 'EEGT6Ref', 'EEGP4Ref', #'EEGPzRef',
          'EEGP3Ref', 'EEGT5Ref', 'EEGO2Ref', 'EEGO1Ref', #'ECGEKG', 'Manual', 'EDFAnnotations'
          ]
x_train = np.load('data/x_train.npz')
y_train = np.load('data/y_train.npz')

In [None]:
len(labels), x_train.shape, y_train.shape

In [None]:
y_train

In [None]:
ADFullerResult(*adfuller(example[:int(120*Fs), 0]))

In [None]:
plot_autocorrelograms(example[:int(1200*Fs), 0], lags=40)

In [None]:
# SciPy recommends calling decimate several times for IIR downsampling factors
# above 13, so the factors 50 and 500 are reached in stages (10*5 and 10*5*10).
eeg_segment = example[:int(1200*Fs), 0]
eeg_by_50 = signal.decimate(signal.decimate(eeg_segment, 10), 5)
eeg_by_500 = signal.decimate(eeg_by_50, 10)
plot_autocorrelograms(eeg_by_50, lags=40)
plot_autocorrelograms(eeg_by_500, lags=40)

In [None]:
example.shape

In [None]:
subsample = 50
# SciPy recommends calling decimate several times for IIR downsampling factors
# above 13, so the factor of 50 is applied as 10 followed by 5. The channels
# are filtered independently along axis 0, so the recording is decimated once
# and the columns are swapped afterwards for the reverse direction.
example_decimated = signal.decimate(signal.decimate(example, 10, axis=0), 5, axis=0)
# grangercausalitytests checks whether the second column Granger-causes the first.
gr_eegxy = grangercausalitytests(example_decimated[:1000], maxlag=10)
gr_eegyx = grangercausalitytests(example_decimated[:1000, ::-1], maxlag=10)