In [None]:
import lightning
import lightning.pytorch.callbacks
from ocd.data import CausalDataModule
from ocd.training import OrderedTrainingModule

In [None]:
# Seed Python's `random`, NumPy and torch RNGs up front so the (presumably random)
# train/val split and the model's weight initialisation in later cells are reproducible.
SEED = 42
lightning.seed_everything(SEED)

# set callbacks for the trainer
callbacks = [
    # log the learning rate to the trainer's logger once per epoch
    lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

trainer = lightning.Trainer(
    # "auto" selects an available GPU/MPS device and falls back to CPU, so the
    # notebook runs on any machine; set accelerator="cpu" to force CPU training.
    accelerator="auto",
    callbacks=callbacks,
    # precision="16-mixed",  # for mixed precision training
    # gradient_clip_val=1.,
)

In [None]:
# Set up the data module for the "asia" dataset and prepare the training split.
dm = CausalDataModule(
    name="asia",  # small dataset asia
    observation_size=2048,  # number of observational samples
    intervention_size=0,  # 0 disables interventional samples
    batch_size=128,
    num_workers=0,  # 0 loads data in the main process (no multiprocessing)
    # fraction of the data held out for validation; an int gives an exact number
    # of samples, and 0 disables validation
    val_size=0.1,
    pin_memory=True,  # faster host-to-GPU transfer (if a GPU is available)
)
dm.setup("fit")


In [None]:
# Feature values of the first training dataset. When val_size > 0 the training data
# is wrapped (presumably a torch Subset from the train/val split) and the underlying
# dataset is reached via `.dataset`. When val_size == 0 the dataset is used directly.
# Detect which case applies, so this cell needs no manual edit when val_size changes.
train_dataset = dm.train_data[0]
if not hasattr(train_dataset, "features_values"):
    train_dataset = train_dataset.dataset
in_features = train_dataset.features_values

In [None]:
# Build the training module. `hidden_features_per_covariate` gets one list per
# width (128, 64, 32), each repeating that width once per input covariate
# (presumably one list per hidden layer — confirm against OrderedTrainingModule).
tm = OrderedTrainingModule(
    in_covariate_features=in_features,
    hidden_features_per_covariate=[
        [width] * len(in_features) for width in (128, 64, 32)
    ],
    optimizer='torch.optim.AdamW',
    # tau decays geometrically with the epoch: 0.09 * 0.9 ** epoch
    tau_scheduler='lambda training_module: 0.09 * (0.9 ** (training_module.current_epoch))',
    lr=0.00001,
    # optional learning-rate schedule:
    # scheduler='torch.optim.lr_scheduler.ExponentialLR',
    # scheduler_interval='epoch',
    # scheduler_args={'gamma': 0.99},
)

In [None]:
# Train the module; the data module supplies the dataloaders.
trainer.fit(tm, datamodule=dm)