<a href="https://colab.research.google.com/github/zhangcun-yan/Awesome-Interaction-Aware-Trajectory-Prediction/blob/master/model_causal_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!python -m pip show scikit-learn

In [None]:
!pip install gcastle

In [None]:
!pip uninstall networkx

In [None]:
!pip install networkx==2.8.0

In [None]:
# import packages
import os
os.environ['CASTLE_BACKEND'] = 'pytorch'
from collections import OrderedDict
import warnings
import numpy as np
import networkx as nx
from scipy import linalg
# from sklearnear_model import LinearRegression
import castle
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import PC, GES
from castle.algorithms import ANMNonlinear, ICALiNGAM, DirectLiNGAM
from castle.algorithms import Notears, NotearsNonlinear, GOLEM
from castle.common.priori_knowledge import PrioriKnowledge
from castle.common.independence_tests import hsic_test
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Node colours used by the graph plots (hex RGB).
COLORS = ['#00B0F0', '#FF0000', '#B0F000']

In [None]:
# Seed NumPy's global RNG for reproducibility.
SEED = 18
np.random.seed(seed=SEED)

In [None]:
# Example scale-free DAG (10 nodes, 17 edges), drawn in the next cell.
adj_matrix = DAG.scale_free(n_nodes=10, n_edges=17, seed=SEED)

In [None]:
# Draw the example adjacency matrix as a directed graph with a circular layout.
g = nx.DiGraph(adj_matrix)
layout = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=layout,
    node_color=COLORS[0],
    node_size=1200,
    arrowsize=17,
    font_size=21,
    font_color='white',
    with_labels=True,
)

In [None]:
# Load the observed TTC-event variables analysed by the discovery methods below.
# read_csv already returns a DataFrame, so no extra pd.DataFrame(...) wrap is needed.
df_causal = pd.read_csv('./ttc_event_var.csv')
df_causal.columns

### PC

In [None]:
import pandas as pd
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import PC
from castle.common.priori_knowledge import PrioriKnowledge
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

In [None]:
# Linear-Gaussian simulation settings; n_nodes matches the number of observed variables.
method, sem_type = 'linear', 'gauss'
n_nodes = df_causal.shape[1]
n_edges, n = 10, 20000

In [None]:
# Re-display the column names (node i in the PC graphs below is column i).
df_causal.columns

In [None]:
# PC without background knowledge on every column of df_causal (alpha = 0.001).
pc = PC(alpha=0.001, variant='original')
pc.learn(df_causal)

# Keep the estimated adjacency as a DataFrame and plot it (no ground truth exists).
df = pd.DataFrame(pc.causal_matrix)
GraphDAG(pc.causal_matrix)

In [None]:
# Draw the plain-PC estimate as a directed graph (node i = column i).
g = nx.DiGraph(pc.causal_matrix)
circle_pos = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=circle_pos,
    with_labels=True,
    node_color=COLORS[0],
    node_size=1000,
    arrowsize=8,
    font_color='black',
    font_size=20,
)

In [None]:
# PC with background knowledge on the 17 analysis variables.
#
# The data matrix is built FIRST so that PrioriKnowledge is sized from the same
# matrix that is passed to `pc.learn` (it used to be sized from df_causal before
# the column selection, which mismatches when df_causal has extra columns such
# as 'slope_ttc'). The unused simulated Erdos-Renyi DAG was removed.
PC_COLUMNS = ['ACT', 'DIS', 'D_MTC', 'D_NMTC', 'mv_v', 'nmv_v', 'mv_acc_m',
              'nmv_acc_m', 'NDIR', 'MN_angle', 'mv_angle_M', 'nmv_angle_M',
              'mv_angle_V', 'nmv_angle_V', 'light_pos', 'mv_env_entropy',
              'nmv_env_entropy']

# pd.DataFrame(frame, columns=[...]) SELECTS columns by label (it does not
# rename them); a listed name absent from the CSV would silently become an
# all-NaN column, so fail loudly instead.
missing_cols = [c for c in PC_COLUMNS if c not in df_causal.columns]
if missing_cols:
    raise KeyError(f'columns not found in df_causal: {missing_cols}')
X = pd.DataFrame(df_causal, columns=PC_COLUMNS)
n_vars = X.shape[1]

# Edge indices refer to positions in PC_COLUMNS (0=ACT, 9=MN_angle, 14=light_pos).
priori = PrioriKnowledge(n_vars)
# Required: DIS, D_MTC, D_NMTC, mv_v, nmv_v, mv_acc_m, nmv_acc_m (1..7) -> ACT (0).
priori.add_required_edges([(i, 0) for i in range(1, 8)])
priori.add_forbidden_edges(
    # ACT must not cause DIS, D_NMTC, nmv_v, mv_acc_m, NDIR, mv_v.
    [(0, 1), (0, 3), (0, 5), (0, 6), (0, 8), (0, 4)]
    # No variable may cause light_pos (14) or MN_angle (9): both are exogenous.
    + [(i, 14) for i in range(n_vars) if i != 14]
    + [(i, 9) for i in range(n_vars) if i != 9]
)
pc = PC(variant='original', priori_knowledge=priori, alpha=0.001)

In [None]:
# Fit the background-knowledge PC configured in the previous cell, then plot it.
pc.learn(X)
df = pd.DataFrame(pc.causal_matrix)   # adjacency as a DataFrame for inspection
GraphDAG(pc.causal_matrix)

In [None]:
# Same plot palette as defined at the top of the notebook.
COLORS = ['#00B0F0', '#FF0000', '#B0F000']

In [None]:
# Draw the PC-with-priors estimate (node i = column i of X).
g = nx.DiGraph(pc.causal_matrix)
circle_pos = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=circle_pos,
    with_labels=True,
    node_color=COLORS[0],
    node_size=1000,
    arrowsize=8,
    font_color='black',
    font_size=20,
)
# To export the figure (use a path relative to the project, not an absolute one):
# plt.savefig('figure/nt_pc.pdf', format='pdf')

## GES

In [None]:
from castle.datasets import DAG, IIDSimulation
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.algorithms.ges.ges import GES

# GES (BIC score, 'scatter' method) on every column of df_causal.
# The former `for d in [3, 6, 9, 12, 15]: ... break` loop ran exactly once and
# the simulated DAG it built was never used, so it has been removed. The fitted
# model keeps the name `algo`, which the next cell draws.
X = df_causal
algo = GES(criterion='bic', method='scatter')
algo.learn(X)

# Plot the estimated DAG (observed data has no ground-truth graph).
GraphDAG(algo.causal_matrix);

In [None]:
# Draw the GES estimate stored in `algo`.
g = nx.DiGraph(algo.causal_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[0],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

In [None]:
# Refit GES as a standalone model `ges`, with the same settings used for `algo`.
ges = GES(method='scatter', criterion='bic')
ges.learn(X)

In [None]:
# Heat-map of the refitted GES adjacency matrix.
GraphDAG(ges.causal_matrix)

In [None]:
# Draw the refitted GES estimate as a directed graph.
g = nx.DiGraph(ges.causal_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[0],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

## GOLEM

In [None]:
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import GOLEM

# GOLEM on the observed data (df_causal).
# The simulated Erdos-Renyi DAG previously generated here was never used (the
# simulation lines were commented out), so it has been removed.
# NOTE: the model is bound to `g`, a name the drawing cells also use for
# networkx graphs; read `g.causal_matrix` before `g` is rebound elsewhere.
g = GOLEM(num_iter=1e4)
g.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(g.causal_matrix)

In [None]:
# Draw the GOLEM estimate.
# The networkx graph gets its own name so `g` (the fitted GOLEM model from the
# previous cell) is not overwritten. Previously this cell rebound `g` to a
# DiGraph, so running it a second time raised AttributeError (no `causal_matrix`).
golem_graph = nx.DiGraph(g.causal_matrix)
plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=golem_graph,
    node_color=COLORS[2],
    node_size=1200,
    arrowsize=17,
    with_labels=True,
    font_color='white',
    font_size=21,
    pos=nx.circular_layout(golem_graph)
)

## NOTEARS (nonlinear)

In [None]:
import os
# NOTE(review): castle presumably reads CASTLE_BACKEND when it is first
# imported; if the top import cell already ran, this line changes nothing.
os.environ['CASTLE_BACKEND'] = 'pytorch'

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import NotearsNonlinear

# NOTEARS-nonlinear on the observed data (df_causal).
# The unused simulated DAG (its dataset lines were commented out) was removed.
nt = NotearsNonlinear()
nt.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(nt.causal_matrix)

In [None]:
# Draw the NOTEARS-nonlinear estimate.
g = nx.DiGraph(nt.causal_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[2],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

## DAG-GNN

In [None]:
import os
os.environ['CASTLE_BACKEND'] = 'pytorch'

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import DAG_GNN

# DAG-GNN on the observed data (df_causal).
# Removed: an ER/SF simulation block whose output was never used (its dataset
# lines were commented out). It also assigned `type = 'ER'`, shadowing the
# builtin `type` for every later cell in the kernel.
gnn = DAG_GNN()
gnn.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(gnn.causal_matrix)

In [None]:
# Draw the DAG-GNN estimate.
g = nx.DiGraph(gnn.causal_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[0],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

## MCSL

In [None]:
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import MCSL

# MCSL with a neural-network model on the observed data (10000 iterations).
# The unused simulated Erdos-Renyi DAG that used to be generated here was removed.
mc = MCSL(model_type='nn',
          iter_step=10000,
          rho_thresh=1e20,
          init_rho=1e-5,
          rho_multiply=10,
          graph_thresh=0.5,
          l1_graph_penalty=2e-3)
mc.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(mc.causal_matrix)

In [None]:
# Draw the MCSL estimate.
g = nx.DiGraph(mc.causal_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[2],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

In [None]:
# Redraw the NOTEARS-nonlinear estimate (`nt`).
# NOTE(review): this duplicates the drawing cell in the "nonlinear" section and
# sits under MCSL; if the MCSL graph was intended, use `mc.causal_matrix`.
g = nx.DiGraph(nt.causal_matrix)
plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    node_color=COLORS[2],
    node_size=1200,
    arrowsize=17,
    with_labels=True,
    font_color='white',
    font_size=21,
    pos=nx.circular_layout(g)
)

## **RL**

In [None]:

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import RL
import pandas as pd

# Reinforcement-learning causal discovery (castle RL) on the observed data.
# The unused simulated Erdos-Renyi DAG (dataset lines were commented out) was removed.
rl = RL(nb_epoch=2000)
rl.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(rl.causal_matrix)

In [None]:
# Draw the RL estimate.
g = nx.DiGraph(rl.causal_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(8, 6))
nx.draw_networkx(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[0],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

## GraN-DAG

In [None]:
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import GraNDAG
import pandas as pd

# GraN-DAG on the observed data (df_causal).
# The previous version used `x` (for input_dim) and `dag` (as the true graph),
# which only existed in the commented-out simulation, so the cell raised
# NameError. input_dim now comes from df_causal. Only the estimate is plotted,
# because observed data has no ground-truth graph.
#
# NOTE: `d` is NOT passed to GraNDAG, so GraNDAG's defaults are used. To try
# these settings, pass them as keyword arguments (check GraNDAG's signature);
# 'device_type': 'gpu' needs a CUDA device.
d = {'model_name': 'NonLinGauss', 'nonlinear': 'leaky-relu', 'optimizer': 'sgd', 'norm_prod': 'paths', 'device_type': 'gpu'}
gnd = GraNDAG(input_dim=df_causal.shape[1])
gnd.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(gnd.causal_matrix)

## MCSL (short run, 100 iterations)

In [None]:
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import MCSL

# Short MCSL run (100 iterations) on the observed data.
# - Uses df_causal explicitly. The old `X` depended on which earlier cell ran
#   last (GES sets X = df_causal; the PC-prior cell sets a 17-column X).
# - Imports GraphDAG so the cell does not depend on earlier imports.
# - Removes the unused simulated Erdos-Renyi DAG.
mc = MCSL(model_type='nn',
          iter_step=100,
          rho_thresh=1e20,
          init_rho=1e-5,
          rho_multiply=10,
          graph_thresh=0.5,
          l1_graph_penalty=2e-3)

mc.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(mc.causal_matrix)

## GAE

In [None]:
from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import GAE

# Graph auto-encoder (GAE) on the observed data.
# - input_dim is the feature dimension of EACH variable (NOTE(review): per
#   castle's GAE documentation). Every column of df_causal is a scalar, so it is
#   1, not 10.
# - The data is passed as an ndarray (with column names separately) because
#   castle's GAE converts it with torch (NOTE(review): confirm against the
#   installed castle version).
# - The unused simulated Erdos-Renyi DAG was removed.
ga = GAE(input_dim=1)
ga.learn(df_causal.to_numpy(), columns=df_causal.columns)

# Plot the estimated DAG.
GraphDAG(ga.causal_matrix)

## CORL

In [None]:

import os
os.environ['CASTLE_BACKEND'] = 'pytorch'

import torch

from castle.common import GraphDAG
from castle.metrics import MetricsDAG
from castle.datasets import DAG, IIDSimulation
from castle.algorithms import CORL

# CORL (ordering-based RL) on the observed data.
# - device_type is lowercase ('gpu'/'cpu', as in the GraNDAG config above).
#   NOTE(review): castle compares it against 'gpu', so the old 'GPU' appears to
#   have silently run on the CPU. The GPU is now used only when CUDA is available.
# - Removed an unused ER/SF simulation that also shadowed the builtin `type`.
# - The model is named `corl` so it no longer overwrites the RL model `rl`.
corl = CORL(encoder_name='transformer',
            decoder_name='lstm',
            reward_mode='episodic',
            reward_regression_type='LR',
            batch_size=64,
            input_dim=64,
            embed_dim=64,
            iteration=2000,
            device_type='gpu' if torch.cuda.is_available() else 'cpu')
corl.learn(df_causal)

# Plot the estimated DAG.
GraphDAG(corl.causal_matrix)

## Other methods

### DECI in code

In [None]:
from dataclasses import dataclass

import numpy as np
import networkx as nx

import torch
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from tensordict import TensorDict

from castle.datasets import DAG, IIDSimulation
from castle.common import GraphDAG
from castle.metrics import MetricsDAG

import causica.distributions as cd

from causica.functional_relationships import ICGNN
from causica.training.auglag import AugLagLossCalculator, AugLagLR, AugLagLRConfig
from causica.graph.dag_constraint import calculate_dagness

from causica.datasets.variable_types import VariableTypeEnum
from causica.datasets.tensordict_utils import tensordict_shapes

import matplotlib.pyplot as plt
# Global matplotlib style for the DECI section; some later cells switch back to 'default'.
plt.style.use('fivethirtyeight')

In [None]:
# Train on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
# Plot palette for the DECI/FCI sections (same values as at the top).
COLORS = ['#00B0F0', '#FF0000', '#B0F000']

In [None]:
# Set random seeds for the DECI experiment.
# NOTE(review): per the Lightning docs, pl.seed_everything also seeds NumPy and
# PyTorch, so the explicit np.random.seed call is likely redundant (but harmless).
SEED = 11
np.random.seed(SEED)
pl.seed_everything(SEED)

In [None]:
# Check that the networkx pinned at the top of the notebook (2.8.0) is the one loaded.
nx.__version__

In [None]:
# Ground-truth DAG for the DECI demo: scale-free, 4 nodes, 6 edges.
adj_matrix = DAG.scale_free(n_nodes=4, n_edges=6, seed=SEED)

# 5000 i.i.d. samples from a nonlinear SEM ('mim') defined on that DAG.
dataset = IIDSimulation(W=adj_matrix, n=5000, method='nonlinear', sem_type='mim')

In [None]:
# Draw the ground-truth DECI graph.
g = nx.DiGraph(adj_matrix)
ring = nx.circular_layout(g)

plt.figure(figsize=(4, 3))
nx.draw(
    G=g,
    pos=ring,
    node_size=1200,
    node_color=COLORS[0],
    font_size=21,
    font_color='white',
    arrowsize=17,
    with_labels=True,
)

In [None]:
# Training hyper-parameters for DECI.
# The fields are annotated so @dataclass treats them as real (frozen) fields.
# Without annotations they were plain class attributes: they could not be
# overridden through the constructor (TrainingConfig(batch_size=256) raised
# TypeError) and were missing from repr()/asdict().
@dataclass(frozen=True)
class TrainingConfig:
    noise_dist: cd.ContinuousNoiseDist = cd.ContinuousNoiseDist.SPLINE
    batch_size: int = 512
    max_epoch: int = 500
    gumbel_temp: float = 0.25            # temperature of the relaxed graph samples
    averaging_period: int = 10
    prior_sparsity_lambda: float = 5.0   # sparsity weight of the Gibbs DAG prior
    init_rho: float = 1.0                # initial augmented-Lagrangian penalty
    init_alpha: float = 0.0              # initial Lagrange multiplier

training_config = TrainingConfig()
auglag_config = AugLagLRConfig()

In [None]:
# One TensorDict entry per simulated variable: 'x<i>' -> column i as an (n, 1) tensor.
data_tensors = {
    f'x{i}': torch.tensor(dataset.X[:, i].reshape(-1, 1))
    for i in range(dataset.X.shape[1])
}
n_rows = dataset.X.shape[0]
dataset_train = TensorDict(data_tensors, torch.Size([n_rows]))

# Move everything to the device as float32 up front
# (for large datasets, move each batch inside the training loop instead).
def _to_float32_on_device(t):
    return t.to(dtype=torch.float32, device=device)

dataset_train = dataset_train.apply(_to_float32_on_device)

# The collate function returns the TensorDict slice unchanged.
dataloader_train = DataLoader(
    dataset=dataset_train,
    batch_size=training_config.batch_size,
    shuffle=True,
    drop_last=False,
    collate_fn=lambda x: x,
)

dataset_train

In [None]:
# Plot the true adjacency matrix of the simulated DAG as a heat-map.
plt.style.use('default')
GraphDAG(adj_matrix)
plt.show()

In [None]:
# Expert knowledge: we believe edge 3 -> 0 exists and edge 0 -> 3 does not.
expert_matrix = torch.zeros(adj_matrix.shape, dtype=torch.float64)
expert_matrix[3, 0] = 1.0

# Relevance mask: the entries of expert_matrix that carry knowledge
# (the present edge 3 -> 0 and the absent edge 0 -> 3).
relevance_mask = expert_matrix.clone()
relevance_mask[0, 3] = 1.0

# Full confidence in every relevant entry.
confidence_matrix = relevance_mask.clone()

In [None]:
# Bundle the expert DAG, its relevance mask and confidences for the prior.
# NOTE(review): `scale` presumably weights the expert term; see causica's ExpertGraphContainer docs.
expert_knowledge = cd.ExpertGraphContainer(
    dag=expert_matrix, mask=relevance_mask, confidence=confidence_matrix, scale=5.0
)

In [None]:
# One graph node per TensorDict key (i.e. per variable).
num_nodes = len(dataset_train.keys())

# Gibbs DAG prior combining a sparsity penalty with the expert knowledge.
prior = cd.GibbsDAGPrior(
    expert_graph_container=expert_knowledge,
    num_nodes=num_nodes,
    sparsity_lambda=training_config.prior_sparsity_lambda,
)

In [None]:
# Build the three DECI sub-modules: adjacency distribution, ICGNN functional
# relationships, and the per-variable noise model.
# NOTE(review): keep this statement order. Modules with trainable parameters
# presumably draw their initial values from torch's global RNG, so reordering
# would change the initialisation.

# Define the adjacency module (a distribution over graphs on num_nodes nodes)
adjacency_dist = cd.ENCOAdjacencyDistributionModule(num_nodes)

# Define the functional module (ICGNN over the variables' shapes)
icgnn = ICGNN(
    variables=tensordict_shapes(dataset_train),
    embedding_size=8, # previously 32
    out_dim_g=8, # previously 32
    norm_layer=torch.nn.LayerNorm,
    res_connection=True,
)

# Define the noise module: every variable is declared continuous
types_dict = {var_name: VariableTypeEnum.CONTINUOUS for var_name in dataset_train.keys()}

noise_submodules = cd.create_noise_modules(
    shapes=tensordict_shapes(dataset_train),
    types=types_dict,
    continuous_noise_dist=training_config.noise_dist
)

# Combine the per-variable noise modules into one joint module
noise_module = cd.JointNoiseModule(noise_submodules)

In [None]:
# Inspect the per-variable noise modules created above.
noise_submodules

In [None]:
# Show the SEMDistributionModule documentation.
# (`?obj` is IPython-only syntax and a leftover development aid; help() works everywhere.)
help(cd.SEMDistributionModule)

In [None]:
# Assemble the SEM from the adjacency, functional and noise modules, then move
# it to the training device (the returned module is displayed).
sem_module = cd.SEMDistributionModule(
    adjacency_module=adjacency_dist,
    functional_relationships=icgnn,
    noise_module=noise_module,
)
sem_module.to(device)

In [None]:
# One Adam parameter group per sub-module, each with the initial learning rate
# that AugLagLRConfig assigns to that group name.
modules = dict(
    icgnn=sem_module.functional_relationships,
    vardist=sem_module.adjacency_module,
    noise_dist=sem_module.noise_module,
)

parameter_list = [
    dict(params=mod.parameters(), lr=auglag_config.lr_init_dict[key], name=key)
    for key, mod in modules.items()
]

optimizer = torch.optim.Adam(parameter_list)

In [None]:
# Augmented-Lagrangian machinery: the learning-rate/penalty scheduler and the
# loss calculator, initialised from the training config.
scheduler = AugLagLR(config=auglag_config)
auglag_loss = AugLagLossCalculator(
    init_rho=training_config.init_rho,
    init_alpha=training_config.init_alpha,
)

In [None]:
# DECI training loop. For every mini-batch:
#   1. draw a relaxed (soft) graph/SEM sample at temperature gumbel_temp,
#   2. score the batch under that SEM,
#   3. add the entropy and prior terms (scaled per sample),
#   4. optimise the augmented-Lagrangian loss with the DAG-ness penalty,
#   5. let the scheduler update the Lagrangian parameters.
# Every 10 epochs (first batch only) it logs metrics and plots the mode of the
# adjacency distribution against the ground truth.
assert len(dataset_train.batch_size) == 1, "Only 1D batch size is supported"

# Total number of rows; used to put the graph-level terms on a per-sample scale.
num_samples = len(dataset_train)

for epoch in range(training_config.max_epoch):

    for i, batch in enumerate(dataloader_train):

        # Zero the gradients
        optimizer.zero_grad()

        # Get the SEM distribution and a relaxed sample from it
        sem_distribution = sem_module()
        sem, *_ = sem_distribution.relaxed_sample(
            torch.Size([]),
            temperature=training_config.gumbel_temp
        )  # soft sample

        # Mean log-likelihood of the batch under the sampled SEM
        batch_log_prob = sem.log_prob(batch).mean()

        # Entropy of the distribution over SEMs/graphs
        sem_distribution_entropy = sem_distribution.entropy()

        # Log-density of the sampled graph under the Gibbs DAG prior
        prior_term = prior.log_prob(sem.graph)

        # Objective: per-sample (-entropy - log prior) minus the data log-likelihood
        objective = (-sem_distribution_entropy - prior_term) / num_samples - batch_log_prob

        # DAG-ness of the sampled graph (the acyclicity constraint value)
        constraint = calculate_dagness(sem.graph)

        # Augmented-Lagrangian loss with the per-sample constraint
        loss = auglag_loss(objective, constraint / num_samples)

        # Propagate gradients and update
        loss.backward()
        optimizer.step()

        # Update the augmented-Lagrangian parameters / learning rates
        scheduler.step(
            optimizer=optimizer,
            loss=auglag_loss,
            loss_value=loss.item(),
            lagrangian_penalty=constraint.item(),
        )

        # Log metrics & plot the matrices (every 10 epochs, first batch only)
        if epoch % 10 == 0 and i == 0:
            print(
                f"epoch:{epoch} loss:{loss.item():.5g} nll:{-batch_log_prob.detach().cpu().numpy():.5g} "
                f"dagness:{constraint.item():.5f} num_edges:{(sem.graph > 0.0).sum()} "
                f"alpha:{auglag_loss.alpha:.5g} rho:{auglag_loss.rho:.5g} "
                f"step:{scheduler.outer_opt_counter}|{scheduler.step_counter} "
                f"num_lr_updates:{scheduler.num_lr_updates}"
            )

            # Mode of the current adjacency distribution vs. the true DAG
            vardist = adjacency_dist()
            pred_dag = vardist.mode.cpu().numpy()

            plt.style.use('default')

            GraphDAG(
                est_dag=pred_dag,
                true_dag=adj_matrix)

            plt.show()

In [None]:
# Final DECI estimate: the MODE (most likely graph) of the learned adjacency
# distribution, not a random sample from it.
vardist = adjacency_dist()
pred_dag = vardist.mode.cpu().numpy()

# Plot the estimate next to the ground truth.
plt.style.use('default')
GraphDAG(est_dag=pred_dag, true_dag=adj_matrix)
plt.show()

In [None]:
# Structural metrics of the DECI estimate against the simulated ground truth.
metrics = MetricsDAG(B_est=pred_dag, B_true=adj_matrix)
metrics.metrics

## FCI

In [None]:
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.graph.GraphNode import GraphNode
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge

In [None]:
# Toy structural model with W as a confounder of X and Y:
#   W -> X -> Z,   W -> Y <- Q
# The RNG calls stay in their original order so the draws are unchanged.
N = 1000

q = np.random.uniform(0, 2, N)
w = np.random.randn(N)
x = w + np.random.gumbel(0, 1, N)
y = 0.6 * q + 0.8 * w + np.random.uniform(0, 1, N)
z = 0.5 * x + np.random.randn(N)

# All five variables, and the observed subset with W hidden (the FCI input).
data = np.column_stack([x, y, w, z, q])
confounded_data = np.column_stack([x, y, z, q])

In [None]:
# Ground-truth graph of the toy model (W is the hidden confounder).
nodes = ['X', 'Y', 'W', 'Z', 'Q']
edges = [('W', 'X'), ('W', 'Y'), ('Q', 'Y'), ('X', 'Z')]

fci_graph = nx.DiGraph()
fci_graph.add_nodes_from(nodes)
fci_graph.add_edges_from(edges)

In [None]:
# Draw the toy graph: observed variables in blue, hidden confounder W in red.
shared_style = dict(
    G=fci_graph,
    pos=nx.circular_layout(fci_graph),
    node_size=1200,
    arrowsize=17,
    with_labels=True,
    font_color='white',
    font_size=21,
)

plt.figure(figsize=(4, 3))
nx.draw_networkx(node_color=COLORS[0], nodelist=['X', 'Y', 'Z', 'Q'], **shared_style)
nx.draw_networkx(node_color=COLORS[1], nodelist=['W'], **shared_style)

In [None]:
# Exclude 'slope_ttc' from the FCI input.
# errors='ignore' makes the cell idempotent: re-running it no longer raises
# KeyError once the column has already been dropped.
df_causal = df_causal.drop(columns=['slope_ttc'], errors='ignore')

### Model

In [None]:
# FCI with the kernel CI test (KCI) on the observed variables; returns (graph, edge list).
g, edges = fci(dataset=df_causal.to_numpy(), independence_test_method='kci')

In [None]:
# Node order follows df_causal's columns (the FCI run above used df_causal).
# fci() returns a causal-learn GeneralGraph. Its endpoint-encoded matrix is
# `g.graph`; there is no `causal_matrix` attribute, so the old call raised
# AttributeError. In causal-learn's encoding, g.graph[j, i] == 1 together with
# g.graph[i, j] == -1 means i --> j; circle/bidirected marks are not edges here.
# Keep only the fully oriented edges i --> j as a 0/1 adjacency for plotting.
fci_adj = ((g.graph.T == 1) & (g.graph == -1)).astype(int)
GraphDAG(fci_adj)

In [None]:
# Readable labels for the toy confounded example (used again by the
# prior-knowledge FCI run on confounded_data further below).
mapping = {
    'X1': 'X',
    'X2': 'Y',
    'X3': 'Z',
    'X4': 'Q'
}

# The FCI run above used df_causal, so label its nodes with df_causal's column
# names; causal-learn names nodes X1..Xn in column order.
fci_labels = {f'X{i}': col for i, col in enumerate(df_causal.columns, start=1)}

# Rebuild each edge string from its tokens ("node1 mark node2 [properties]").
# str.replace('X1', ...) also rewrote 'X10', 'X11', ... and unmapped names
# raised KeyError.
for edge in edges:
    name1, name2 = str(edge.node1), str(edge.node2)
    tokens = str(edge).split()
    print(' '.join([fci_labels.get(name1, name1), tokens[1],
                    fci_labels.get(name2, name2), *tokens[3:]]))

In [None]:
# Debug aid: the string forms of the last edge and its first node (safe to delete).
str(edge), str(edge.node1)

## Model with prior knowledge

In [None]:
# Background knowledge for the toy data (column order x, y, z, q -> X1..X4):
# forbid X2 -> X4 (y -> q) and require X1 -> X3 (x -> z).
prior_knowledge = BackgroundKnowledge()
prior_knowledge.add_forbidden_by_node(GraphNode('X2'), GraphNode('X4'))
prior_knowledge.add_required_by_node(GraphNode('X1'), GraphNode('X3'))

# FCI with the Fisher-z test on the confounded toy data (W hidden).
g, edges = fci(
    dataset=confounded_data,
    background_knowledge=prior_knowledge,
    independence_test_method='fisherz',
)

In [None]:
# Print the toy-data FCI edges with readable labels (X1 -> X, X2 -> Y, ...).
# The string is rebuilt from its tokens ("node1 mark node2 [properties]") instead
# of using str.replace, so a name that prefixes another (X1 vs X10) is never
# rewritten by mistake. Unmapped names are printed unchanged instead of raising KeyError.
for edge in edges:
    name1, name2 = str(edge.node1), str(edge.node2)
    tokens = str(edge).split()
    print(' '.join([mapping.get(name1, name1), tokens[1],
                    mapping.get(name2, name2), *tokens[3:]]))