# Train an OCD MADE and evaluate its performance

In this notebook, we train an Ordered Causal Discovery (OCD) model based on the MADE architecture and evaluate its performance on a dataset from `bnlearn` (`asia` by default).

We use PyTorch Lightning to train and evaluate our models. The MADE implementation is based on the code from the `torchde` library, modified to support reordering the variables with permutation matrices.
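To make the idea of reordering variables with a permutation matrix concrete, here is a minimal NumPy sketch. It is only an illustration and not the code used inside the `ocd` package; the names `order`, `P`, and `x` are hypothetical.

```python
import numpy as np

# Hypothetical ordering: variable 2 comes first, then variable 0, then variable 1.
order = [2, 0, 1]

# Row i of P is the one-hot vector selecting variable order[i].
P = np.eye(3, dtype=int)[order]

# A single sample with three variables.
x = np.array([10, 20, 30])

# Multiplying by P reorders the variables according to `order`.
x_permuted = P @ x
print(x_permuted)

# The transpose of a permutation matrix undoes the reordering.
x_restored = P.T @ x_permuted
print(x_restored)
```

Because every row and column of `P` contains exactly one 1, `P.T @ P` is the identity, which is why the transpose recovers the original order.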

## Setup
### Import libraries

In [1]:
# add parent directory to path (to use ocd package)
import sys

sys.path.append("..")

# set PYTORCH_ENABLE_MPS_FALLBACK=1 in current environment to enable MPS fallback
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# import packages
import dycode  # for dynamic code execution
import numpy as np
import torch
import lightning
import lightning.pytorch.callbacks
from ocd.training import OrderedTrainingModule
from ocd.data import CausalDataModule  # for loading data
from ocd.training.terms import OrderedLikelihoodTerm, PermanentMatrixPenalizer
from ocd.models.utils import log_prob
from ocd.evaluation import shd
from ocd.post_processing.pruning import prune, PruningMethod


# register torch in dynamic code context
dycode.register_context(torch)


## Setup Data
Set the `name` argument to the name of the `bnlearn` dataset you want to use. To use a dataset that does not ship with `bnlearn`, pass the download link of its `.bif` file as `name` instead; the file is then downloaded and used to generate the observational and interventional samples.

As a side note, `intervention_size` is the number of samples drawn per value of the intervened variable. For example, if a node takes 3 values and the intervention size is 100, then 300 samples are generated for interventions on that node.
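The arithmetic above can be written as a one-line helper. This is a sketch only; `interventional_sample_count` is a hypothetical name and is not part of the `ocd` package.

```python
def interventional_sample_count(num_values: int, intervention_size: int) -> int:
    """Total samples for one intervened node: `intervention_size` samples per value."""
    return num_values * intervention_size


# The example from the text: a node with 3 values and an intervention size of 100.
print(interventional_sample_count(3, 100))

# A node with 2 values and the intervention size used in the next cell (256).
print(interventional_sample_count(2, 256))
```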


In [2]:
# define the data
dm = CausalDataModule(
    name="asia",  # change this to your desired dataset from bnlearn
    observation_size=2048,  # number of observation samples
    intervention_size=256,  # set to 0 for no intervention
    batch_size=128,  # batch size for training (and validation if val_size > 0)
    num_workers=0,  # set to 0 for no multiprocessing
    val_size=0,  # set to 0 for no validation, fraction for relative validation size, or int for absolute validation size
    pin_memory=True,  # set to True for faster data transfer to GPU (if available)
)

# read general stats of the data to setup the model later
dm.setup("fit")  # prepare the data to read general stats (e.g. number of features and nodes)
in_features = dm.train_data[0].features_values  # number of features (used for setting the input size of the model)


## Train the model

In [3]:
# set the model and training parameters
tm = OrderedTrainingModule(
    # input parameters
    in_covariate_features=in_features,  # number of nodes and features in the data
    # network architecture
    hidden_features_per_covariate=[
        [32 for i in range(len(in_features))],
        [16 for i in range(len(in_features))],
        [8 for i in range(len(in_features))],
    ],
    batch_norm=False,
    # training objective
    criterion_args=dict(
        terms=[
            "ocd.training.terms.OrderedLikelihoodTerm",
            # PermanentMatrixPenalizer(factor=1), # uncomment this line to add a permanent matrix penalizer
        ],
    ),
    # training parameters
    optimizer=["torch.optim.AdamW", "torch.optim.SGD"],
    optimizer_parameters=["model.made", "model._gamma"],
    optimizer_is_active=[
        "lambda training_module: True",
        "lambda training_module: True",
    ],
    tau_scheduler="lambda training_module: max(0.0005, 0.01 * 0.5 ** (training_module.current_epoch // 5))",
    n_sinkhorn_scheduler="lambda training_module: 20",
    lr=[0.001, 1],
    scheduler="torch.optim.lr_scheduler.ExponentialLR",
    scheduler_interval="epoch",
    scheduler_args={"gamma": 0.999},
    # log
    # log_permutation=True, # uncomment this line to log the permutation matrix
    # log_permutation_freq=1,
)
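To see what the `tau_scheduler` expression in the cell above evaluates to over the course of training, the sketch below applies the same formula to a few epochs. The `SimpleNamespace` object is a hypothetical stand-in for the training module; the only attribute the expression reads is `current_epoch`.

```python
from types import SimpleNamespace


def tau_scheduler(training_module):
    """Same expression as the `tau_scheduler` string passed to the training module."""
    return max(0.0005, 0.01 * 0.5 ** (training_module.current_epoch // 5))


for epoch in (0, 4, 5, 10, 20, 25, 99):
    # Hypothetical stand-in object exposing only `current_epoch`.
    stub = SimpleNamespace(current_epoch=epoch)
    print(f"epoch {epoch:3d}: tau = {tau_scheduler(stub):.7f}")
```

The value starts at 0.01, halves every 5 epochs, and stays at the floor of 0.0005 from epoch 25 onward.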


### Setup trainer and fit the model to the data

In [4]:
# set callbacks for the trainer
callbacks = [
    # monitor the learning rate (log to tensorboard)
    lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

trainer = lightning.Trainer(
    # accelerator="mps",  # remove this line to run on CPU
    callbacks=callbacks,
    # precision=16, # for mixed precision training
    # gradient_clip_val=1.0,
    # gradient_clip_algorithm="value",
    track_grad_norm="inf",
    log_every_n_steps=1,
    # detect_anomaly=True,
    max_epochs=100,
)

# train the model
trainer.fit(tm, dm)


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                   | Params
-------------------------------------------------
0 | model | SinkhornOrderDiscovery | 46.6 K
-------------------------------------------------
46.6 K    Trainable params
0         Non-trainable params
46.6 K    Total params
0.186     Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=100` reached.


### Prune the learned model to remove redundant edges using observational data

In [5]:
ground_truth = dm.datasets[0].dag  # ground truth DAG

pruned_dag = prune(
    ordering=tm.get_ordering(),  # learned ordering
    data=dm.datasets[0].samples,  # observational data
    method=PruningMethod.CONDITIONAL_INDEPENDENCE_TESTING,  # pruning method
    verbose=1,  # print progress
    method_params=dict(threshold=0.05),  # confidence threshold for conditional independence testing
)

print("The number of edges in the original DAG is:", np.sum(ground_truth).astype(int))
print("Structural hamming distance between pruned_dag and original_dag is:", shd(pruned_dag, ground_truth))

100%|██████████| 8/8 [00:00<00:00, 53.90it/s]

The number of edges in the original DAG is: 8
Structural hamming distance between pruned_dag and original_dag is: 8
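For readers unfamiliar with the metric, the sketch below shows one common way to compute a structural Hamming distance between two DAG adjacency matrices. It is an assumption that `ocd.evaluation.shd` follows the same convention; in particular, this version counts a reversed edge once, while other conventions count it twice. The function name `structural_hamming_distance` is hypothetical.

```python
import numpy as np


def structural_hamming_distance(pred, true) -> int:
    """Count missing, extra, and reversed edges between two DAG adjacency matrices.

    Assumption: a reversed edge counts as a single error.
    """
    pred = np.asarray(pred, dtype=bool)
    true = np.asarray(true, dtype=bool)
    diff = pred != true  # directed entries where the two graphs disagree
    # A reversal changes both (i, j) and (j, i), so count each unordered pair once.
    pair_diff = diff | diff.T
    return int(np.triu(pair_diff, k=1).sum())


# Ground truth: 0 -> 1 and 1 -> 2.
true_dag = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
# Prediction: 1 -> 0 (reversed), 0 -> 2 (extra), and 1 -> 2 is missing.
pred_dag = np.array([[0, 0, 1], [1, 0, 0], [0, 0, 0]])

print(structural_hamming_distance(pred_dag, true_dag))
```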





## Further prune using interventional data
As mentioned in the [pruning notebook](./pruning.ipynb), we can prune the model further using interventional data. We only need to make sure that we do not prune the parents of the intervened variable.
Because the interventions are unknown, we treat the node with the minimum average log-likelihood over the whole interventional episode (dataset) as the intervened node.
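The selection rule described above can be sketched on a toy array of per-sample, per-node log-likelihoods. The numbers below are made up purely for illustration, and `guess_intervened_node` is a hypothetical helper, not a function of the `ocd` package.

```python
import numpy as np


def guess_intervened_node(log_probs) -> int:
    """Return the index of the node with the lowest mean log-likelihood.

    `log_probs` has shape (num_samples, num_nodes).
    """
    return int(np.asarray(log_probs).mean(axis=0).argmin())


# Toy episode: 3 samples, 3 nodes; node 1 is explained worst by the model.
toy_log_probs = np.array(
    [
        [-0.2, -2.5, -0.4],
        [-0.3, -3.1, -0.1],
        [-0.1, -1.9, -0.6],
    ]
)

print(guess_intervened_node(toy_log_probs))
```

The cell below applies the same rule with `torch`, averaging the per-node log-probabilities over every batch of an interventional episode.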

In [6]:
ground_truth = dm.datasets[0].dag
count_incorrect = 0
# get the interventional dataframes with the values of the intervention
for intervention_episode in dm.datasets[1:]:
    intervention_node_name = intervention_episode.intervention_node
    # look up the index of the intervention node in intervention_episode.samples.columns
    gt_intervention_node_index = intervention_episode.samples.columns.get_loc(intervention_node_name)
    # iterate over the episode in order to score every sample under the trained model
    dataloader = torch.utils.data.DataLoader(intervention_episode, batch_size=32, shuffle=False)

    # collect the per-node log_prob of every sample in the episode
    log_probs = []
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(tm.device)
            processed_batch = tm.process_batch(batch).to(tm.device)
            logits = tm.model(processed_batch)
            log_probs.append(log_prob(logits, in_features, batch, reduce=False))

    # index of the node with the lowest log_prob on average over the episode
    intervention_node_index = torch.cat(log_probs).mean(0).argmin().item()

    # count the number of misclassifications
    count_incorrect += int(intervention_node_index != gt_intervention_node_index)

    # prune the resulting DAG
    pruned_dag = prune(
        ordering=tm.get_ordering(),  # learned ordering
        data=intervention_episode.samples,  # interventional data
        method=PruningMethod.CONDITIONAL_INDEPENDENCE_TESTING,  # pruning method
        dag=pruned_dag,  # DAG to prune (result of previous prunings)
        interventional_column=intervention_node_index,  # index of the node being intervened on
        verbose=1,  # print progress
        method_params=dict(threshold=0.05),  # confidence threshold for conditional independence testing
    )

print("Number of misclassification of nodes under intervention:", count_incorrect)
print("The number of edges in the original DAG is: ", np.sum(ground_truth).astype(int))
print("Structural hamming distance between pruned_dag and original_dag is: ", shd(pruned_dag, ground_truth))


100%|██████████| 8/8 [00:00<00:00, 353.37it/s]
100%|██████████| 8/8 [00:00<00:00, 396.63it/s]
100%|██████████| 8/8 [00:00<00:00, 363.57it/s]
100%|██████████| 8/8 [00:00<00:00, 349.67it/s]
100%|██████████| 8/8 [00:00<00:00, 412.26it/s]
100%|██████████| 8/8 [00:00<00:00, 575.51it/s]
100%|██████████| 8/8 [00:00<00:00, 812.85it/s]
100%|██████████| 8/8 [00:00<00:00, 1092.16it/s]

Number of misclassification of nodes under intervention: 3
The number of edges in the original DAG is:  8
Structural hamming distance between pruned_dag and original_dag is:  7



