In [None]:
# Make the parent directory importable so the local `ocd` package resolves.
# The membership check keeps this cell idempotent: re-running it no longer
# appends a duplicate ".." entry to sys.path each time.
import sys

if ".." not in sys.path:
    sys.path.append("..")

# Set PYTORCH_ENABLE_MPS_FALLBACK=1 for the current process to enable the MPS
# fallback. This assignment is deliberately placed before `import torch`.
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Third-party and project imports (must come after the path/env setup above).
import torch
import lightning
import lightning.pytorch.callbacks
from ocd.training import OrderedTrainingModule


In [None]:
# Callbacks handed to the trainer: a single monitor that records the learning
# rate once per epoch through whichever logger the trainer uses.
lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")
callbacks = [lr_monitor]

# Trainer configuration.
# NOTE(review): `track_grad_norm` is not accepted by every Lightning release
# (I believe it was dropped from Trainer in 2.0) — confirm against the
# installed version before upgrading.
trainer = lightning.Trainer(
    callbacks=callbacks,
    accelerator="mps",  # Apple MPS backend; change or remove this entry to train on another device
    log_every_n_steps=1,  # log on every training step
    track_grad_norm="inf",  # track the infinity norm of the gradients
    gradient_clip_algorithm="value",  # clip gradients by value ...
    gradient_clip_val=1.0,  # ... using 1.0 as the clipping threshold
    # Optional switches, currently disabled:
    # precision=16, # for mixed precision training
    # detect_anomaly=True,
)


In [None]:
# Build and prepare the data module used for training.
from ocd.data import CausalDataModule
import dycode
import torch

# presumably makes `torch` resolvable inside code that dycode evaluates from
# strings — TODO confirm against the dycode documentation
dycode.register_context(torch)

data_config = {
    # dataset to load; the earlier note mentioned "asia" as a small dataset,
    # presumably another accepted name — verify against ocd.data
    "name": "alarm",
    "observation_size": 2048,  # number of observation samples
    "intervention_size": 256,  # set to 0 for no intervention
    "batch_size": 128,
    "num_workers": 0,  # 0 means no multiprocessing for data loading
    # 0 disables validation; per the author's note a float is a fraction
    # (e.g. 0.1 for 10% of the data) and an int an exact number of samples
    "val_size": 0,
    "pin_memory": True,  # per the author's note: faster data transfer to GPU (if available)
}
dm = CausalDataModule(**data_config)

# Run the "fit" stage setup now so `dm.train_data` can be inspected before training.
dm.setup("fit")


In [None]:
# Read `features_values` from the first training dataset.
# With val_size == 0 the dataset exposes `features_values` directly; with a
# validation split it is reportedly wrapped and the real dataset sits behind
# `.dataset`. Unwrap only when needed, so this cell works for either setting
# without being edited by hand.
_train_set = dm.train_data[0]
if not hasattr(_train_set, "features_values"):
    # presumably a torch.utils.data.Subset-style wrapper — TODO confirm
    _train_set = _train_set.dataset
in_features = _train_set.features_values


In [None]:
# Debug aid, disabled by default:
# torch.autograd.set_detect_anomaly(True)

# Three lists (filled with 64, 32 and 8 respectively), each holding one entry
# per element of `in_features`.
num_covariates = len(in_features)
hidden_layout = [[width] * num_covariates for width in (64, 32, 8)]

# Tau schedule, passed as a source string. It evaluates to 0.98 ** epoch while
# epoch < 1000 and to the constant 0.01 from epoch 1000 onwards.
# NOTE(review): 0.98 ** epoch is already below 0.01 from about epoch 228, so
# the value jumps *upwards* at epoch 1000 — confirm this is intended.
tau_schedule = "lambda training_module: 1 * (0.98 ** training_module.current_epoch if training_module.current_epoch < 1000 else 0.01)"

# Loss configuration: only the ordered likelihood term is active.
criterion_config = {
    "terms": [
        "ocd.training.terms.OrderedLikelihoodTerm",
        # Further example terms, currently disabled:
        # dict(
        #     name="norm(gamma)",
        #     term_function='lambda training_module: training_module.model.Gamma.norm(float("inf"))',
        #     factor=0,
        # ),
        # dict(
        #     name='nothing',
        #     term_function='def term(training_module, batch):\n\ttraining_module.batch=batch\n\treturn torch.zeros(1, device=batch.device)',
        #     factor=0,
        # )
    ]
}

tm = OrderedTrainingModule(
    # model
    embedding_dim=8,
    embedding_normalization=2,
    in_covariate_features=in_features,
    hidden_features_per_covariate=hidden_layout,
    batch_norm=False,
    # loss
    criterion_args=criterion_config,
    # optimisation: AdamW at lr 1e-3, exponential LR schedule with gamma 0.99
    # on an "epoch" interval
    optimizer="torch.optim.AdamW",
    lr=0.001,
    scheduler="torch.optim.lr_scheduler.ExponentialLR",
    scheduler_args={"gamma": 0.99},
    scheduler_interval="epoch",
    tau_scheduler=tau_schedule,
)


In [None]:
# Start training: fit the training module `tm` on the data module `dm`.
trainer.fit(model=tm, datamodule=dm)
