In [None]:
# Add the parent directory to the import path so the local `ocd` package can be
# imported (assumes this notebook sits one level below the package root — TODO confirm).
import sys

sys.path.append("..")

# Set PYTORCH_ENABLE_MPS_FALLBACK=1 in the current process environment to enable MPS fallback.
# It is set before `torch` is imported below; presumably it has to be — TODO confirm.
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Third-party and project packages used by the cells below.
# NOTE(review): OrderedLikelihoodTerm and PermanentMatrixPenalizer are not referenced by name
# in the visible cells (the likelihood term is passed as a dotted-path string; the penalizer
# only appears in a commented-out line).
import torch
import lightning
import lightning.pytorch.callbacks
from ocd.training import OrderedTrainingModule
from ocd.training.terms import OrderedLikelihoodTerm, PermanentMatrixPenalizer
from ocd.data import CausalDataModule # data module used to load the data
import dycode # for dynamic code execution
import torch  # NOTE(review): duplicate of the `import torch` above; redundant but harmless

# Register torch in dycode's dynamic code context (presumably so that string-valued settings
# such as "torch.optim.AdamW" can be resolved — TODO confirm against dycode).
dycode.register_context(torch)



In [None]:
# Trainer callbacks: a single monitor that records the learning rate once per epoch
# (logged to tensorboard).
callbacks = [lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")]

# Optional Trainer switches, currently left disabled:
#   accelerator="mps"                 (leave out to run on CPU)
#   precision=16                      (mixed precision training)
#   gradient_clip_val=1.0
#   gradient_clip_algorithm="value"
#   detect_anomaly=True
# NOTE(review): `track_grad_norm` is only accepted by some Lightning releases — confirm that
# the installed version still supports it.
trainer = lightning.Trainer(
    log_every_n_steps=1,
    track_grad_norm="inf",
    callbacks=callbacks,
)


In [None]:
# Data module built from the bnlearn repository file "survey.bif.gz".
dm = CausalDataModule(
    name="https://www.bnlearn.com/bnrepository/survey/survey.bif.gz",
    batch_size=128,
    observation_size=2048,  # number of observation samples
    intervention_size=256,  # 0 means no intervention
    val_size=0,  # 0 = no validation (presumably a float is a fraction and an int an exact sample count — confirm)
    num_workers=0,  # 0 means no multiprocessing
    pin_memory=True,  # faster data transfer to GPU (if available)
)

# Prepare the data now so that its general statistics can be read before the model is built.
dm.setup("fit")
# Feature description of the first training dataset; it is passed to the model as
# `in_covariate_features`, and its length is used as the number of covariates.
in_features = dm.train_data[0].features_values


In [None]:
# Configure the training module: model inputs, network shape, objective and optimisation.
tm = OrderedTrainingModule(
    # --- inputs ---
    in_covariate_features=in_features,  # feature description read from the data module
    # --- network architecture ---
    # hidden widths 32 -> 16 -> 8, each repeated once per entry of `in_features`
    hidden_features_per_covariate=[[width] * len(in_features) for width in (32, 16, 8)],
    batch_norm=False,
    # --- training objective ---
    criterion_args={
        "terms": [
            "ocd.training.terms.OrderedLikelihoodTerm",
            # PermanentMatrixPenalizer(factor=1),
        ],
        # disabled example of a regularization entry:
        # "regularizations": [
        #     dict(
        #         name="nothing",
        #         term_function="lambda batch: torch.zeros(1, device=batch[0].device)",
        #         factor="def factor(training_module, results_dict):\n\ttraining_module.loss = results_dict['loss']\n\treturn 0",
        #     )
        # ],
    },
    # --- optimisation ---
    # Two optimizers; presumably matched by position with the parameter groups, the
    # activity switches and the learning rates below — confirm in OrderedTrainingModule.
    optimizer=["torch.optim.AdamW", "torch.optim.SGD"],
    optimizer_parameters=["model.made", "model._gamma"],
    optimizer_is_active=["lambda training_module: True"] * 2,
    lr=[0.001, 1],
    # exponential learning-rate decay with gamma=0.999 and interval "epoch"
    scheduler="torch.optim.lr_scheduler.ExponentialLR",
    scheduler_interval="epoch",
    scheduler_args={"gamma": 0.999},
    # tau: 0.01 at epoch 0, halved every 5 epochs, never below 0.0005
    tau_scheduler="lambda training_module: max(0.0005, 0.01 * 0.5 ** (training_module.current_epoch // 5))",
    # constant alternative:
    # tau_scheduler="lambda training_module: 0.001",
    # n_sinkhorn: fixed at 20 for every epoch
    n_sinkhorn_scheduler="lambda training_module: 20",
    # epoch-dependent alternative:
    # n_sinkhorn_scheduler="lambda training_module: min(60, max(20, 20 + ((training_module.current_epoch - 20) // 10)))",
    # --- logging (disabled) ---
    # log_permutation=True,
    # log_permutation_freq=1,
)


In [None]:
# Train the model: fit the training module `tm` on the data module `dm` with the
# trainer configured above.
trainer.fit(tm, dm)


In [None]:
# Turn the learned ordering into a DAG by pruning, then compare it with the ground truth.
import numpy as np
from ocd.evaluation import shd
from ocd.post_processing.pruning import PruningMethod, prune

# dm.datasets[0] supplies both the reference DAG and the samples used for pruning here.
ground_truth = dm.datasets[0].dag

pruned_dag = prune(
    ordering=tm.get_ordering(),
    data=dm.datasets[0].samples,
    method=PruningMethod.CONDITIONAL_INDEPENDENCE_TESTING,
    method_params={"threshold": 0.05},
    verbose=1,
)

# Report the size of the reference DAG and the structural Hamming distance to it.
print("The number of edges in the original DAG is:", np.sum(ground_truth).astype(int))
print("Structural hamming distance between pruned_dag and original_dag is:", shd(pruned_dag, ground_truth))


## Further prune using interventional data

In [None]:
from ocd.models.utils import log_prob


In [None]:
ground_truth = dm.datasets[0].dag
count_incorrect = 0

# Every dataset after the first is handled as one interventional episode.
# NOTE(review): `pruned_dag` is refined from its own previous value, so re-running this cell
# without re-running the pruning cell above starts from an already refined DAG.
for episode in dm.datasets[1:]:
    # Column position of the node recorded as intervened on; used only to score the guess below.
    true_node_index = episode.samples.columns.get_loc(episode.intervention_node)

    # Score the episode under the trained model, batch by batch, without tracking gradients.
    # NOTE(review): the model is not switched to eval mode here — confirm that is intended.
    loader = torch.utils.data.DataLoader(episode, batch_size=32, shuffle=False)
    episode_log_probs = []
    with torch.no_grad():
        for raw_batch in loader:
            on_device = raw_batch.to(tm.device)
            model_out = tm.model(tm.process_batch(on_device).to(tm.device))
            episode_log_probs.append(log_prob(model_out, in_features, on_device, reduce=False))

    # Guess the intervened node as the index with the lowest log-prob after averaging over
    # the first axis (presumably the sample axis — TODO confirm the shape returned by log_prob).
    guessed_node_index = torch.cat(episode_log_probs).mean(0).argmin().item()
    count_incorrect += (guessed_node_index != true_node_index)

    # Refine the current DAG with this episode, using the guessed node as the interventional column.
    pruned_dag = prune(
        ordering=tm.get_ordering(),
        data=episode.samples,
        method=PruningMethod.CONDITIONAL_INDEPENDENCE_TESTING,
        method_params=dict(threshold=0.05),
        dag=pruned_dag,
        interventional_column=guessed_node_index,
        verbose=1,
    )

# Number of episodes whose intervened node was guessed wrongly, followed by the final scores.
print("wrong", count_incorrect)
print("The number of edges in the original DAG is: ", np.sum(ground_truth).astype(int))
print("Structural hamming distance between pruned_dag and original_dag is: ", shd(pruned_dag, ground_truth))
