# Inference of causal graphs from data

Author: Marcell Stippinger

Date: 2025-10-03

## Contents

* Learn how to generate a causal graph
* Create probability distribution over the graph
* Sample from the probability distribution
* Infer causal relations from data (PC, FCI)

## Preparation, global imports

In [None]:
# Install package if not already installed
!pip install causal-learn

In [None]:
from copy import copy, deepcopy
from IPython.display import Image
import numpy as np
import pandas as pd
import networkx as nx
from sklearn.utils import check_random_state
from scipy import stats

## Generate graph and distribution

In [None]:
def dag_with_preferential_attachment(num_nodes, root_cause_proba=0.3, parent_count=2.0, random_seed=20251003):
    """Generate a random DAG with preferential attachment.

    Nodes are created one at a time with indices ``0, 1, ..., num_nodes - 1``.
    Each new node is a root cause with probability ``root_cause_proba``; node 0
    is always a root. Otherwise, the node receives ``1 + Poisson(max(parent_count - 1, 0))``
    parents, capped at the number of existing nodes. The parents are drawn
    without replacement from the existing nodes, with probability proportional
    to ``1 + (number of children the candidate already has)``. Nodes that
    already have many children are therefore more likely to get new ones.

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the DAG.
    root_cause_proba : float
        Probability of a node being a root cause (having no parents).
    parent_count : float
        Expected number of parents for non-root nodes, before capping at the
        number of already existing nodes.
    random_seed : int
        Seed for the random number generator. It is passed to
        ``check_random_state``, so an existing ``RandomState`` also works.

    Returns
    -------
    dag : dict
        A dictionary representing the DAG, where keys are node indices and values are lists of parent node indices.
        Note: in this representation indices are in a possible temporal order (parents always have smaller indices than their children).
    """
    rng = check_random_state(random_seed)
    dag = {}
    # preference[k] == 1 + number of children node k has received so far
    preference = []

    for i_node in range(num_nodes):
        # The uniform draw is evaluated before the ``i_node == 0`` test on
        # purpose. Every node consumes exactly one draw, so a given seed always
        # yields the same graph.
        is_root = stats.uniform.rvs(random_state=rng) < root_cause_proba or i_node == 0
        if is_root:
            parents = []
        else:
            potential_parents = list(dag)
            num_parents = stats.poisson.rvs(np.maximum(parent_count - 1, 0), random_state=rng) + 1
            weights = np.asarray(preference, dtype=float)
            parents = rng.choice(
                potential_parents,
                size=np.minimum(num_parents, i_node),  # cannot pick more parents than exist
                replace=False,
                p=weights / weights.sum(),
            ).tolist()
        dag[i_node] = parents
        preference.append(1)
        for p in parents:
            preference[p] += 1

    return dag

# Quick look at a random 10-node DAG (root-cause probability 0.1, on average 2 parents).
# Note: ``example_dag`` is reassigned to the sprinkler DAG by ``example_dag = sprinkler_dag()``
# further below, so the rest of the notebook works with the sprinkler example.
example_dag = dag_with_preferential_attachment(10, root_cause_proba=0.1, parent_count=2.0,)
print(example_dag)

In [None]:
def sprinkler_dag():
    """Build the classic five-variable "sprinkler" network.

    The variables, with their parents, are:
    0 Season (root) -> 1 Sprinkler and 2 Rain -> 3 Wet Grass -> 4 Slippery Road.

    Returns
    -------
    dag : dict
        Maps each node index to a list of its parent node indices.
        Indices follow a valid temporal order: every parent has a smaller index than its children.
    """
    # (variable name, parent indices) in topological order; a node's index is
    # its position in this sequence.
    structure = (
        ("Season", ()),
        ("Sprinkler", (0,)),
        ("Rain", (0,)),
        ("Wet Grass", (1, 2)),
        ("Slippery Road", (3,)),
    )
    return {index: list(parents) for index, (_name, parents) in enumerate(structure)}

# From here on the notebook uses the hand-specified sprinkler DAG, which replaces
# any random DAG bound to ``example_dag`` earlier. Skip this line to keep the random DAG.
example_dag = sprinkler_dag()

In [None]:
def visualize_dag(dag, seed=None):
    """Visualize a DAG using NetworkX and Matplotlib.

    Parameters
    ----------
    dag : dict
        A dictionary representing the DAG, where keys are node indices and values are lists of parent node indices.
    seed : int or None
        Seed forwarded to ``nx.spring_layout``. Pass an int to get the same node
        placement on every call. The default ``None`` gives a random layout,
        as before.
    """
    import matplotlib.pyplot as plt

    G = nx.DiGraph()
    for child, parents in dag.items():
        # Add every key explicitly, so isolated nodes are drawn too.
        G.add_node(child)
        for parent in parents:
            G.add_edge(parent, child)  # edge direction: parent -> child

    pos = nx.spring_layout(G, seed=seed)
    nx.draw(G, pos, with_labels=True, arrows=True)
    plt.show()

# Draw whichever DAG is currently bound to ``example_dag``.
visualize_dag(example_dag)

In [None]:
def create_probability_distributions(dag, states=2, random_seed=20251003, concentration=1.0):
    """Create random (conditional) probability tables for each node in the DAG.

    Every distribution over a node's ``states`` values is drawn from a symmetric
    Dirichlet distribution with parameter ``concentration``. With the default
    ``concentration=1.0``, the draw is uniform over the probability simplex. This
    does not make the node's distribution uniform. Smaller values give more
    peaked, near-deterministic tables, and thus stronger dependencies. Larger
    values give tables closer to uniform, and thus weaker dependencies.

    Parameters
    ----------
    dag : dict
        A dictionary representing the DAG, where keys are node indices and values are lists of parent node indices.
    states : int
        Number of discrete states for each node.
    random_seed : int
        Seed for the random number generator.
    concentration : float
        Symmetric Dirichlet concentration parameter. Must be positive.

    Returns
    -------
    distributions : dict
        A dictionary where keys are node indices and values are probability distributions as a function of the state of their parents.
        For a root node the value is an array of shape ``(states,)``. For a node
        with ``k`` parents it is an array of shape ``(states,) * (k + 1)``.
        ``table[s_1, ..., s_k]`` is the node's distribution when its parents,
        in the order listed in ``dag``, are in states ``s_1, ..., s_k``.
    """
    rng = check_random_state(random_seed)
    alpha = np.full(states, concentration, dtype=float)
    distributions = {}

    for node, parents in dag.items():
        if not parents:
            # Root node: one random distribution drawn from Dirichlet(alpha).
            prob = rng.dirichlet(alpha)
            assert prob.shape == (states, )
            assert np.allclose(prob.sum(), 1)
            distributions[node] = prob
        else:
            # Non-root node: one Dirichlet draw per combination of parent states,
            # reshaped (row-major) so that the leading axes index the parents' states.
            parent_states = [states] * len(parents)
            table_shape = parent_states + [states]
            prob_table = rng.dirichlet(alpha, size=np.prod(parent_states)).reshape(table_shape)
            assert prob_table.shape == tuple(table_shape)
            assert np.allclose(prob_table.sum(axis=-1), 1)
            distributions[node] = prob_table

    return distributions

# One random (conditional) probability table per node of ``example_dag``,
# with the default 2 states per node and the default seed.
example_distro = create_probability_distributions(example_dag)
print(example_distro)

## Example data

In [None]:
def _topological_order(dag):
    """Return the nodes of ``dag`` ordered so that every parent precedes its children.

    The traversal follows the key order of ``dag``. If the keys are already in
    a valid parents-first order, the result equals ``list(dag)``.

    Raises
    ------
    ValueError
        If a parent is not itself a key of ``dag``, or if the graph contains a directed cycle.
    """
    order = []
    status = {}  # node -> "open" while its ancestors are being visited, "done" afterwards

    for start in dag:
        if start in status:
            continue
        status[start] = "open"
        # Iterative depth-first search over parents (avoids Python's recursion limit).
        stack = [(start, iter(dag[start]))]
        while stack:
            node, pending_parents = stack[-1]
            for parent in pending_parents:
                if parent not in dag:
                    raise ValueError(f"parent {parent!r} of node {node!r} is not a node of the DAG")
                if status.get(parent) == "open":
                    raise ValueError(f"the graph contains a directed cycle through node {parent!r}")
                if parent not in status:
                    status[parent] = "open"
                    stack.append((parent, iter(dag[parent])))
                    break
            else:
                # All parents of ``node`` are placed; ``node`` can follow them.
                stack.pop()
                status[node] = "done"
                order.append(node)
    return order


def generate_observations(dag, distributions, num_samples=10000, random_seed=20251003):
    """Generate observations from the DAG and its probability distributions.

    Nodes are sampled parents-first (ancestral sampling), so ``dag`` does not
    need its keys in topological order. When the keys already are in that
    order, sampling follows the key order exactly, and the data for a given seed
    are the same as with a plain key-order loop.

    Parameters
    ----------
    dag : dict
        A dictionary representing the DAG, where keys are node indices and values are lists of parent node indices.
    distributions : dict
        A dictionary where keys are node indices and values are probability distributions as a function of the state of their parents.
        A root node maps to a 1-D probability vector. A node with parents maps to
        a table indexed by the parents' states, in the order listed in ``dag``,
        whose last axis is a probability vector.
    num_samples : int
        Number of samples to generate.
    random_seed : int
        Seed for the random number generator.

    Returns
    -------
    data : pd.DataFrame
        A DataFrame where each column corresponds to a node and each row is a sample.

    Raises
    ------
    ValueError
        If ``dag`` has a directed cycle or refers to a parent that is not one of its nodes.
    """
    rng = check_random_state(random_seed)
    data = pd.DataFrame(index=np.arange(num_samples), columns=list(dag))

    for node in _topological_order(dag):
        parents = dag[node]
        if not parents:
            # Root node: sample from its distribution
            prob = distributions[node]
            data[node] = rng.choice(len(prob), size=num_samples, p=prob)
        else:
            # Non-root node: sample based on the parents' states, which are
            # already filled in because parents are visited first.
            prob_table = distributions[node]
            parent_states = data[parents].values
            samples = []
            # we could use np.take to parallelize this, but it's less readable
            for ps in parent_states:
                prob = prob_table[tuple(ps)]
                samples.append(rng.choice(len(prob), p=prob))
            data[node] = samples

    return data

# Draw 10000 samples from ``example_dag`` using the tables in ``example_distro``.
# Show the first rows: columns are node indices, values are state indices.
example_data = generate_observations(example_dag, example_distro, num_samples=10000)
example_data.head()

## Inference of causal graphs from data

### The PC algorithm (named after Peter Spirtes and Clark Glymour)

Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT press.

**Output:** `cg`, a `CausalGraph` object, where
* `cg.G.graph[j,i] = 1` and `cg.G.graph[i,j] = -1` indicate $i \to j$;
* `cg.G.graph[i,j] = cg.G.graph[j,i] = -1` indicate $i - j$ (undirected edge);
* `cg.G.graph[i,j] = cg.G.graph[j,i] = 1` indicate $i \leftrightarrow j$.


In [None]:
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import chisq

# The generated data are discrete (categorical states). Use the chi-squared
# conditional-independence test instead of the default Fisher-z test, which
# assumes approximately Gaussian, linearly related variables.
cg = pc(example_data.values, indep_test=chisq)

# or customized parameters
#cg = pc(example_data.values, alpha, indep_test, stable, uc_rule, uc_priority, mvpc, correction_name, background_knowledge, verbose, show_progress)

# visualization using pydot; node labels are the column names of example_data, passed as strings
cg.draw_pydot_graph(labels=[str(c) for c in example_data.columns])

### FCI algorithm (Fast Causal Inference)

Spirtes, P., Meek, C., & Richardson, T. (1995, August). Causal inference in the presence of latent variables and selection bias. In Proceedings of the Eleventh conference on Uncertainty in artificial intelligence (pp. 499-506).

**Output:** `g`, a `GeneralGraph` object, where `g.graph` is a partial ancestral graph (PAG). Writing `G = g.graph`, the edges are read as follows:
* $A\to B$ (A is an ancestor, i.e. a cause, of B) if G[B,A] = 1 and G[A,B] = -1;
* $A\circ\!\!\to B$ (B is not an ancestor of A) if G[B,A] = 1 and G[A,B] = 2;
* $A\circ\!\!-\!\!\circ B$ (no set d-separates A and B; the orientation is undetermined) if G[A,B] = G[B,A] = 2;
* $A\leftrightarrow B$ (latent confounder: neither is an ancestor of the other) if G[A,B] = G[B,A] = 1.

In [None]:
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.cit import chisq

# The generated data are discrete (categorical states). Use the chi-squared
# conditional-independence test instead of the default Fisher-z test, which
# assumes approximately Gaussian, linearly related variables.
g, edges = fci(example_data.values, independence_test_method=chisq)

# or customized parameters
# g, edges = fci(example_data.values, independence_test_method, alpha, depth, max_path_length, verbose, background_knowledge, cache_variables_map)

# visualization using pydot
from causallearn.utils.GraphUtils import GraphUtils

# node labels are the column names of example_data, passed as strings
pdy = GraphUtils.to_pydot(g, labels=[str(c) for c in example_data.columns])
#pdy.write_png('simple_test.png')
Image(pdy.create_png(prog='dot'))

Exercises:
* Try different sample sizes.
* Try different graph structures.
* Can you write down the independence and conditional independence relations implied by the graph?
* Does the output belong to the observational (Markov) equivalence class of the true graph?