# Sanity checks

Here we perform a number of sanity checks on the model to verify that it is working as intended.

### Checking optimal log-likelihoods per permutation

For the end-to-end model to work — learning both the causal ordering and the model parameters — different permutations should produce different log-likelihoods, with the optimal permutation achieving the best possible likelihood.

In [1]:
# Enable IPython's autoreload extension (its mode is configured in the next cell)
# so that edits to imported modules are picked up without restarting the kernel.
# This is the function form of the `%load_ext autoreload` line magic.
get_ipython().run_line_magic("load_ext", "autoreload")

In [2]:
%autoreload 2
import sys
sys.path.append('..')


# setup data
from ocd.data import CausalDataModule
import dycode
import torch

# setup model
import lightning
import lightning.pytorch.callbacks
from ocd.training import OrderedTrainingModule


dm = CausalDataModule(
    name="asia",  # small dataset asia
    observation_size=4096,  # number of observation samples
    intervention_size=256,  # set to 0 for no intervention
    batch_size=64,
    num_workers=0,  # set to 0 for no multiprocessing
    val_size=0,  # 10% of data for validation, or use int for exact number of samples, set to 0 for no validation
    pin_memory=True,  # set to True for faster data transfer to GPU (if available)
)
dm.setup("fit")


# Extract the category sizes
in_features = dm.train_data[0].features_values


Set up a fixed permutation below (or leave it unset to learn the ordering end to end) and train the model.

In [3]:
import random

from lightning.pytorch import loggers as pl_loggers

from ocd.utils import topological_sort

# When False, skip the interactive prompt and run the default 'whole' experiment.
prompt = False

# Set the permutation to be used, use the seed to switch between permutations
print("Enter experiment type:")
print("\t[whole] run the whole algorithm")
print("\t[correct_order] fix the permutation to a topological sort of the true DAG")
print("\t[INT] create a random permutation that is fixed and obtained from the input seed")
resp = 'whole' if not prompt else input("enter: ").strip()

VERSION = resp  # TensorBoard run version (directory name) for this experiment
FIXED_PERMUTATION = None  # None -> no fixed permutation is passed to the training module

if resp == 'whole':
    pass
elif resp == 'correct_order':
    FIXED_PERMUTATION = topological_sort(dm.datasets[0].dag)
    print(f"Permutation being used\n{FIXED_PERMUTATION}")
else:
    try:
        seed = int(resp)
    except ValueError as err:
        raise ValueError(
            f"Unknown experiment type {resp!r}: expected 'whole', 'correct_order' or an integer seed"
        ) from err
    VERSION = f'random_fixed_{seed}'
    # A private RNG gives exactly the same shuffle as random.seed(seed) + random.shuffle,
    # without resetting the global `random` state used elsewhere in the session.
    rng = random.Random(seed)
    FIXED_PERMUTATION = list(range(len(in_features)))
    rng.shuffle(FIXED_PERMUTATION)
    print(f"Permutation being used\n{FIXED_PERMUTATION}")

logger = pl_loggers.tensorboard.TensorBoardLogger("lightning_logs", name="sanity_check", version=VERSION)

# set callbacks for the trainer
callbacks = [
    # monitor the learning rate (log to tensorboard)
    lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

trainer = lightning.Trainer(
    # accelerator="mps",  # remove this line to run on CPU
    callbacks=callbacks,
    # precision=16, # for mixed precision training
    # gradient_clip_val=1.0,
    # gradient_clip_algorithm="value",
    max_epochs=10000,
    # NOTE(review): `track_grad_norm` is not accepted by the Lightning 2.x Trainer;
    # this relies on an older release -- confirm the installed version.
    track_grad_norm="inf",
    log_every_n_steps=1,
    logger=logger,
    # overfit_batches=3,
    # detect_anomaly=True,
)
# Presumably exposes `torch` to the string-defined lambdas/terms evaluated via
# dycode in the training module config (see the next cell) -- confirm.
dycode.register_context(torch)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Enter experiment type:
	[whole] run the whole algorith
	[correct_order] run the whole algorith
	[INT] create a random permutation that is fixed and obtained from the input seed


Set up the training module using the fixed permutation and check the value of the loss after convergence.

In [4]:
# To debug failing backward passes, uncomment torch.autograd.set_detect_anomaly(True) below.
import torch

# Per-covariate feature values from the first entry of dm.train_data
# (presumably the category size of each covariate -- confirm in ocd.data).
in_features = dm.train_data[0].features_values


# torch.autograd.set_detect_anomaly(True)
tm = OrderedTrainingModule(
    in_covariate_features=in_features,
    # Three hidden layers of widths 128, 64 and 32, the same width for every covariate.
    hidden_features_per_covariate=[[width] * len(in_features) for width in (128, 64, 32)],
    fixed_permutation=FIXED_PERMUTATION,
    log_permutation=True,
    log_permutation_freq=1,
    batch_norm=False,
    criterion_args=dict(
        terms=[
            "ocd.training.terms.OrderedLikelihoodTerm",
            # "ocd.training.terms.PermanentMatrixPenalizer",
            # The two norm terms below have factor=0; presumably they are only
            # logged for monitoring and do not change the loss -- confirm.
            dict(
                name="norm(gamma)",
                term_function='lambda training_module: training_module.model.Gamma.norm(float("inf"))',
                factor=0,
            ),
            dict(
                name="norm(layers)",
                term_function='lambda training_module: max([layer.linear.weight.norm(float("inf")) for layer in training_module.model.made.layers])',
                factor=0,
            ),
            # dict(
            #     name='nothing',
            #     term_function='def term(training_module, batch):\n\ttraining_module.batch=batch\n\treturn torch.zeros(1, device=batch.device)',
            #     factor=0,
            # )
        ]
    ),
    # Two Adam optimizers (presumably index-aligned with the parameter groups):
    # one for `model.made` with weight decay, one for `model.Gamma` without.
    optimizer=['torch.optim.Adam', 'torch.optim.Adam'],
    optimizer_parameters=['model.made', 'model.Gamma'],
    optimizer_args=[
        dict(
            weight_decay=0.0001,
        ),
        dict()
    ],
    # NOTE: `current_epoch % 10` is always in 0..9, so both conditions are always
    # True. These presumably gate each optimizer per epoch; lower the threshold to
    # alternate between them.
    optimizer_is_active=[
        'lambda training_module: training_module.current_epoch % 10 < 10',
        'lambda training_module: training_module.current_epoch % 10 < 10',
    ],
    # tau decays geometrically from 0.5 by a factor of 0.89 per epoch, floored at 0.003.
    tau_scheduler="lambda training_module: max(0.003, 0.5 * 0.89 ** (training_module.current_epoch // 1))",
    # Sinkhorn iteration count: 20 until epoch 29, then +1 every 10 epochs, capped at 60.
    n_sinkhorn_scheduler="lambda training_module: min(60, max(20, 20 + ((training_module.current_epoch - 20) // 10)))",
    lr=0.001,
    scheduler="torch.optim.lr_scheduler.ExponentialLR",
    scheduler_interval="epoch",
    scheduler_args={"gamma": 0.999},
)


In [5]:
# Train the module on the data module; naming both arguments makes explicit that
# `dm` is passed as a LightningDataModule rather than as train dataloaders.
trainer.fit(model=tm, datamodule=dm)



  | Name  | Type                   | Params
-------------------------------------------------
0 | model | SinkhornOrderDiscovery | 677 K 
-------------------------------------------------
677 K     Trainable params
0         Non-trainable params
677 K     Total params
2.711     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]