In [1]:
import lightning
import lightning.pytorch.callbacks
from ocd.training import OrderedTrainingModule




In [2]:
# Trainer configuration.
# Callback: log the learning rate once per epoch to the trainer's logger.
lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")
callbacks = [lr_monitor]

# Optional switches, currently disabled (add them to the call below to enable):
#   accelerator="mps"    -> train on an Apple-silicon GPU
#   precision=16         -> mixed-precision training
#   detect_anomaly=True  -> autograd anomaly detection (debugging aid)
trainer = lightning.Trainer(
    callbacks=callbacks,
    # clip gradients element-wise by value at 1.0
    gradient_clip_val=1.0,
    gradient_clip_algorithm="value",
    # NOTE(review): `track_grad_norm` was removed from the Trainer in Lightning 2.0;
    # the "IPU/HPU available" banner below suggests a 1.x release — confirm before upgrading.
    track_grad_norm="inf",
    log_every_n_steps=1,
)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Data: build the causal datamodule for the small "asia" dataset (loaded via bnlearn,
# see the log below) and prepare it for fitting.
from ocd.data import CausalDataModule
import dycode
import torch

# register `torch` with dycode — presumably so string specs used later
# (e.g. "torch.optim.AdamW") can be resolved; TODO confirm
dycode.register_context(torch)

datamodule_kwargs = dict(
    name="asia",
    observation_size=2048,  # number of observational samples
    intervention_size=256,  # number of interventional samples; 0 disables interventions
    batch_size=128,
    num_workers=0,  # 0 = load data in the main process (no multiprocessing)
    # held-out validation data: a float fraction (e.g. 0.1 = 10%) or an int sample
    # count; 0 (current setting) disables validation
    val_size=0,
    pin_memory=True,  # faster host-to-GPU transfer when a GPU is in use
)
dm = CausalDataModule(**datamodule_kwargs)
dm.setup("fit")


[bnlearn] >Import <asia>
[bnlearn] >Loading bif file </home/hamidreza/.local/lib/python3.8/site-packages/bnlearn/data/asia.bif>
[bnlearn] >Check whether CPDs sum up to one.
[bnlearn] >Check whether CPDs associated with the nodes are consistent: True


In [14]:
# Gather the samples of every dataset in the datamodule (each `dataset.samples` is a
# pandas DataFrame) into one frame and inspect the support of each covariate.
import pandas as pd

# ignore_index avoids duplicate row labels coming from the individual datasets
samples = pd.concat([dataset.samples for dataset in dm.datasets], ignore_index=True)
# number of distinct values taken by each covariate (Series indexed by column)
unique_values_count = samples.nunique()
# the distinct values themselves: a Series of arrays, one per covariate.
# Built explicitly rather than with `samples.apply(lambda x: x.unique())`, because
# `apply` expands equal-length array results into a DataFrame — which happens
# whenever all covariates have the same cardinality (e.g. all-binary datasets).
unique_values = pd.Series(
    {column: samples[column].unique() for column in samples.columns},
    dtype=object,
)


In [16]:
# Per-covariate category counts as recorded by the datamodule.
# With val_size > 0 the training split is wrapped and the counts live on the
# underlying `.dataset`; unwrap automatically instead of toggling commented code.
train_dataset = dm.train_data[0]
if not hasattr(train_dataset, "features_values"):
    train_dataset = train_dataset.dataset
in_features = train_dataset.features_values
# the same counts computed directly from the raw samples, for cross-checking
in_features2 = unique_values_count.values


In [17]:
# Sanity check: category counts from the datamodule must match those computed from
# the raw samples. Fail loudly on Run-All instead of relying on eyeballing the diff.
import numpy as np

assert np.array_equal(in_features, in_features2), "per-covariate category counts disagree"
in_features - in_features2

array([0, 0, 0, 0, 0, 0, 0, 0])

In [18]:
# Build the ordered training module. To debug autograd, uncomment
# `torch.autograd.set_detect_anomaly(True)` below.
import torch

# torch.autograd.set_detect_anomaly(True)
n_covariates = len(in_features)
tm = OrderedTrainingModule(
    in_covariate_features=in_features,
    # one list of widths per hidden layer (presumably), each with one entry per covariate
    hidden_features_per_covariate=[
        [128] * n_covariates,
        [64] * n_covariates,
        [32] * n_covariates,
    ],
    batch_norm=False,
    criterion_args=dict(
        terms=[
            "ocd.training.terms.OrderedLikelihoodTerm",
            # Tensor.norm(inf) is the largest absolute entry of Gamma; with factor=0 this
            # is presumably only logged, not added to the loss — confirm.
            dict(
                name="norm(gamma)",
                term_function='lambda training_module: training_module.model.Gamma.norm(float("inf"))',
                factor=0,
            ),
            # Debug hook: stores the current batch as `training_module.batch` so later
            # cells can read it via `tm.batch`.
            # dict(
            #     name='nothing',
            #     term_function='def term(training_module, batch):\n\ttraining_module.batch=batch\n\treturn torch.zeros(1, device=batch.device)',
            #     factor=0,
            # )
        ]
    ),
    optimizer="torch.optim.AdamW",
    # Temperature: decays by 2% per epoch, floored at 0.01.
    # The previous `0.98 ** epoch if epoch < 1000 else 0.01` already fell below 0.01 at
    # epoch ~228, reached ~1.7e-9 by epoch 999, then jumped back up to 0.01 at epoch 1000.
    tau_scheduler="lambda training_module: max(0.98 ** training_module.current_epoch, 0.01)",
    lr=0.001,
    scheduler="torch.optim.lr_scheduler.ExponentialLR",
    scheduler_interval="epoch",
    scheduler_args={"gamma": 0.99},
)


In [19]:
# Train the module; the dataloaders come from the datamodule.
trainer.fit(model=tm, datamodule=dm)


[bnlearn] >Import <asia>
[bnlearn] >Loading bif file </home/hamidreza/.local/lib/python3.8/site-packages/bnlearn/data/asia.bif>
[bnlearn] >Check whether CPDs sum up to one.
[bnlearn] >Check whether CPDs associated with the nodes are consistent: True



  | Name  | Type                   | Params
-------------------------------------------------
0 | model | SinkhornOrderDiscovery | 677 K 
-------------------------------------------------
677 K     Trainable params
0         Non-trainable params
677 K     Total params
2.711     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x8 and 16x1024)

In [12]:
import numpy as np
# Number of covariates.
# NOTE(review): the "37" output below is stale — this cell ran as In [12], before the
# asia datamodule cells (In [13]+), whose diff check shows 8 covariates.
# `np` is not used in this cell.
len(in_features)

37

In [None]:
# NOTE(review): nothing visible in this notebook sets `temp_batch` (the commented-out
# debug term assigns `training_module.batch`), so this may raise AttributeError.
tm.temp_batch


In [None]:
# per-covariate category counts passed to the model as `in_covariate_features`
in_features


In [None]:
# `tm.batch` exists only if the commented-out "nothing" debug term in the
# OrderedTrainingModule cell was enabled during training (it assigns training_module.batch).
batch = tm.batch


In [None]:
# Undo the one-hot encoding: split the encoded columns into one chunk per covariate
# (chunk width = that covariate's category count) and take the index of the hot column.
# Unlike a fixed `reshape(batch.shape[0], -1, 2)`, this also handles covariates with
# more than two categories. It raises, rather than silently mis-decoding, when the
# batch width does not equal sum(in_features).
covariate_widths = [int(width) for width in in_features]
original_batch = torch.stack(
    [chunk.argmax(dim=-1) for chunk in torch.split(batch, covariate_widths, dim=-1)],
    dim=-1,
)


In [None]:
# decoded batch: one category index per covariate
original_batch


In [None]:
# Evaluate the training criterion by hand on the captured batch, everything on CPU.
cpu_device = torch.device("cpu")
tm = tm.to(cpu_device)
loss = tm.criterion(
    training_module=tm,
    batch=batch.to(cpu_device),
    original_batch=original_batch.to(cpu_device),
)


In [None]:
# Forward the captured batch through the MADE component, passing `tm.model.Gamma`
# directly (no Sinkhorn normalisation) as the second argument — presumably the
# (soft) permutation input; compare with the sinkhorn-based call below.
tm.model.made(batch.to("cpu"), tm.model.Gamma)


In [None]:
from ocd.models.sinkhorn import sinkhorn

# Same forward pass, but with Gamma first passed through `sinkhorn`.
# NOTE(review): 0.01 and 100 are presumably the temperature and the number of
# Sinkhorn iterations — confirm against ocd.models.sinkhorn.
tm.model.made(batch.to("cpu"), sinkhorn(tm.model.Gamma, 0.01, 100))


In [None]:
# Run MADE's layers by hand and apply OrderedLinear.forward with the density estimator
# to inspect the raw logits.
# OrderedLinear was previously imported only in a later cell, so this cell failed on a
# fresh kernel; importing here makes it self-contained (re-importing is harmless).
from ocd.models.layers.ordered_linear import OrderedLinear

# NOTE(review): `P` (the hard permutation matrix) is defined in a later cell; run that
# cell first or this raises NameError.
logits, p = OrderedLinear.forward(tm.model.made.density_estimator, *tm.model.made.layers((batch.to("cpu"), P)))


In [None]:
# Numerical-health check: True if exp(logits) has any NaN *or* inf entry.
# `isnan` alone misses overflow, because exp() of a large finite logit is inf, not NaN.
(~torch.isfinite(logits.exp())).any()


In [None]:
from ocd.models.layers.ordered_linear import OrderedLinear


In [None]:
# Turn the Sinkhorn (soft) permutation into a hard permutation matrix P.
# A row-wise argmax can send two rows to the same column when Sinkhorn has not fully
# converged, so P would not be a permutation. Solving the assignment problem always
# yields a valid permutation, and it matches the argmax (up to ties) whenever the
# argmax already is one.
from scipy.optimize import linear_sum_assignment

soft_assignment = sinkhorn(tm.model.Gamma, 0.01, 100)
_, assigned_columns = linear_sum_assignment(soft_assignment.detach().cpu().numpy(), maximize=True)
order = torch.as_tensor(assigned_columns, device=soft_assignment.device)

# create a permutation matrix for the order
P = torch.zeros_like(tm.model.Gamma)
P[torch.arange(P.shape[0]), order] = 1


In [None]:
# Report every parameter tensor that contains NaN or inf entries, by name and count,
# instead of printing the (possibly huge, unlabelled) tensor — once per problem kind.
for name, param in tm.named_parameters():
    n_nan = int(torch.isnan(param).sum())
    n_inf = int(torch.isinf(param).sum())
    if n_nan or n_inf:
        print(f"{name}: shape={tuple(param.shape)}, nan={n_nan}, inf={n_inf}")


In [None]:
# Sinkhorn output with 0.1 as the second argument, instead of the 0.01 used in other
# cells (presumably a higher temperature, i.e. a softer permutation — confirm).
# Renamed from an unprofessional placeholder name; nothing visible referenced it.
soft_permutation = sinkhorn(tm.model.Gamma, 0.1, 100)
soft_permutation


In [None]:
# Exclusive prefix sums of the widths 1..10, i.e. [0, 1, 3, 6, ..., 55] as int32;
# consecutive entries serve as column-slice boundaries in the `mat[:, ...]` cells.
cumsums = torch.cat([torch.zeros(1, dtype=torch.long), torch.arange(1, 11).cumsum(dim=0)]).to(torch.int32)


In [None]:
# toy per-covariate widths 1, 2, ..., 10 (int64)
cov_features = torch.arange(10) + 1


In [None]:
# default Tensor.norm() (p="fro"): square root of the sum of squares of all entries.
# NOTE(review): `mat` is only defined in a later cell; this fails on a fresh kernel.
mat.norm()


In [None]:
# inspect the slice boundaries
cumsums


In [None]:
# NOTE(review): `mat` is only defined in a later cell; this fails on a fresh kernel.
mat.shape


In [None]:
# Toy 1 x 55 float32 row vector holding 1, 2, ..., 55 (55 = 1 + 2 + ... + 10).
mat = torch.arange(1, 56, dtype=torch.float32).unsqueeze(0)

# per-covariate slices: [mat[:, cumsums[i] : cumsums[i + 1]] for i in range(len(cov_features))]


In [None]:
# columns of the 10th (last) slice: cumsums[9]..cumsums[10] = 45..55
mat[:, cumsums[9] : cumsums[10]]


In [None]:
# Largest absolute entry of `mat`: with no `dim`, Tensor.norm(inf) treats the tensor as
# a flat vector, so this is NOT the matrix inf-norm (maximum absolute row sum).
mat.norm(float("inf"))
