In [1]:
import sys

# Make the parent directory importable; presumably that is where the local
# `ocd` package imported in the next cell lives.
# The membership check keeps the cell idempotent: re-running it no longer
# appends another duplicate '..' entry to sys.path every time.
if ".." not in sys.path:
    sys.path.append("..")

In [1]:
import lightning
import lightning.pytorch.callbacks
from ocd.training import OrderedTrainingModule


In [2]:
# Load IPython's autoreload extension so that edits to locally imported
# modules (e.g. the `ocd` package) are picked up without restarting the
# kernel; it is switched on with `%autoreload 2` in the next cell.
%load_ext autoreload

In [3]:
%autoreload 2
# set callbacks for the trainer
callbacks = [
    # monitor the learning rate (log to tensorboard)
    lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

trainer = lightning.Trainer(
    # accelerator="mps",  # remove this line to run on CPU
    callbacks=callbacks,
    # precision=16, # for mixed precision training
    # gradient_clip_val=1.0,
    # gradient_clip_algorithm="value",
    max_epochs=10000,
    track_grad_norm="inf",
    log_every_n_steps=1,
    # overfit_batches=3,
    # detect_anomaly=True,
)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
# setup data
from ocd.data import CausalDataModule
import dycode
import torch

# Register `torch` with dycode; presumably this exposes it to the code strings
# (lambda schedules / criterion terms) that dycode evaluates later.
dycode.register_context(torch)

dm = CausalDataModule(
    name="asia",  # small dataset asia
    observation_size=4096,  # number of observation samples
    intervention_size=256,  # set to 0 for no intervention
    batch_size=128,
    num_workers=0,  # set to 0 for no multiprocessing
    # 0 disables validation. Otherwise use a float fraction of the data
    # (e.g. 0.1 for 10%) or an int for an exact number of samples.
    val_size=0,
    pin_memory=True,  # set to True for faster data transfer to GPU (if available)
)
# Set the data up eagerly (Lightning calls setup again inside trainer.fit),
# because the next cell reads the feature sizes from dm.train_data.
dm.setup("fit")


[bnlearn] >Import <asia>
[bnlearn] >Loading bif file </home/hamidreza/Work/myprojects/ocd/venv/lib/python3.10/site-packages/bnlearn/data/asia.bif>
[bnlearn] >Check whether CPDs sum up to one.
[bnlearn] >Check whether CPDs associated with the nodes are consistent: True


In [5]:
# To debug NaNs / backward errors, call torch.autograd.set_detect_anomaly(True)
# before building the module.
import torch

# Per-covariate input sizes read from the first training sample
# (presumably the number of categories of each variable — TODO confirm).
in_features = dm.train_data[0].features_values
n_covariates = len(in_features)

# Three hidden layers; within a layer every covariate gets the same width.
hidden_widths = (128, 64, 32)

# Objective terms. Both norm terms use factor=0, so presumably they are only
# computed/logged for monitoring and do not contribute to the loss — verify.
criterion_terms = [
    "ocd.training.terms.OrderedLikelihoodTerm",
    dict(
        name="norm(gamma)",
        term_function='lambda training_module: training_module.model.Gamma.norm(float("inf"))',
        factor=0,
    ),
    dict(
        name="norm(layers)",
        term_function='lambda training_module: max([layer.linear.weight.norm(float("inf")) for layer in training_module.model.made.layers])',
        factor=0,
    ),
]

# Gate shared by both optimizers. `epoch % 10` is always in 0..9, so `< 10`
# is always true and both optimizers step every epoch; lower the right-hand
# 10 to make them alternate.
optimizer_gate = 'lambda training_module: training_module.current_epoch % 10 < 10'

tm = OrderedTrainingModule(
    in_covariate_features=in_features,
    hidden_features_per_covariate=[[width] * n_covariates for width in hidden_widths],
    bias=False,
    batch_norm=False,
    criterion_args=dict(terms=criterion_terms),
    # Two Adam optimizers, presumably paired by position with the parameter
    # groups below: `model.made` (with weight decay) and `model.Gamma`.
    optimizer=['torch.optim.Adam', 'torch.optim.Adam'],
    optimizer_parameters=['model.made', 'model.Gamma'],
    optimizer_args=[dict(weight_decay=0.0001), dict()],
    optimizer_is_active=[optimizer_gate, optimizer_gate],
    # tau = 0.5 * 0.98**epoch, floored at 0.003 (floor reached around epoch 254);
    # presumably the Sinkhorn temperature.
    tau_scheduler="lambda training_module: max(0.003, 0.5 * 0.98 ** (training_module.current_epoch // 1))",
    # 20 iterations until epoch 30, then +1 every 10 epochs, capped at 60;
    # presumably the number of Sinkhorn iterations.
    n_sinkhorn_scheduler="lambda training_module: min(60, max(20, 20 + ((training_module.current_epoch - 20) // 10)))",
    lr=0.001,
    # Learning rate is multiplied by 0.999 after every epoch.
    scheduler="torch.optim.lr_scheduler.ExponentialLR",
    scheduler_interval="epoch",
    scheduler_args={"gamma": 0.999},
)


In [10]:
# Train the ordering module on the asia datamodule (passed by keyword so it is
# explicit that `dm` is a datamodule, not a train dataloader).
trainer.fit(model=tm, datamodule=dm)


[bnlearn] >Import <asia>
[bnlearn] >Loading bif file </home/hamidreza/Work/myprojects/ocd/venv/lib/python3.10/site-packages/bnlearn/data/asia.bif>
[bnlearn] >Check whether CPDs sum up to one.
[bnlearn] >Check whether CPDs associated with the nodes are consistent: True



  | Name  | Type                   | Params
-------------------------------------------------
0 | model | SinkhornOrderDiscovery | 675 K 
-------------------------------------------------
675 K     Trainable params
0         Non-trainable params
675 K     Total params
2.704     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 20it [00:00, ?it/s]

In [9]:
# Show the `dag` matrix of dataset index 2 (presumably the ground-truth
# adjacency matrix of the asia network — TODO confirm what index 2 is).
# Read it from `dm` itself rather than `dm.trainer.datamodule`: that is the same
# datamodule, but the trainer is only attached once trainer.fit has run, so the
# old form depended on execution order. This assumes dm.setup("fit") already
# populated dm.datasets.
print(dm.datasets[2].dag)

[[0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 1.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]
