In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../..")  # add parent directory to path
from lightning_toolbox import DataModule, TrainingModule
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from ocd.models.permutation.utils import all_permutations, abbriviate_permutation
from ocd.data import SyntheticOCDDataset
from matplotlib import pyplot as plt
import copy
import random
import yaml
import wandb
# load and setup base configs
# Each file is opened in a `with` block so its handle is closed as soon as the
# YAML has been parsed (the bare `yaml.load(open(...))` form leaves it open).
# NOTE(review): FullLoader is kept as-is; these are local, trusted config files.
with open('config/data.yaml') as config_file:
    data_config = yaml.load(config_file, Loader=yaml.FullLoader)['init_args']
with open('config/model-simplified.yaml') as config_file:
    model_config = yaml.load(config_file, Loader=yaml.FullLoader)['init_args']
with open('config/trainer.yaml') as config_file:
    # the trainer config is used whole (no 'init_args' sub-key is extracted)
    training_config = yaml.load(config_file, Loader=yaml.FullLoader)

In [None]:
# utility functions
def set_model_config(n, model_config, num_transforms=2, num_layers=2, ordering=None):
    """Return a deep copy of ``model_config`` specialised for ``n`` features.

    The input mapping is never modified. In the copy, the ``"model_args"``
    entry receives:

    * ``num_transforms`` -- taken from the argument of the same name,
    * ``layers``         -- a list holding ``n`` repeated ``num_layers`` times,
    * ``ordering``       -- the given ordering, or ``[0, ..., n-1]`` when it is ``None``,
    * ``in_features``    -- ``n``.
    """
    configured = copy.deepcopy(model_config)
    configured["model_args"].update(
        num_transforms=num_transforms,
        layers=[n] * num_layers,
        ordering=list(range(n)) if ordering is None else ordering,
        in_features=n,
    )
    return configured


def set_data_config(n, data_config, true_perm):
    """Return a deep copy of ``data_config`` with the graph size and ordering set.

    In the copy, ``dataset_args -> scm_generator_args -> graph_generator_args``
    gets ``enforce_ordering = true_perm`` and ``n = n``. The input mapping is
    left untouched; ``true_perm`` itself is stored by reference, not copied.
    """
    configured = copy.deepcopy(data_config)
    graph_args = configured["dataset_args"]["scm_generator_args"]["graph_generator_args"]
    graph_args["enforce_ordering"] = true_perm
    graph_args["n"] = n
    return configured


def describe_data(dataset, true_perm):
    """Show two figures describing ``dataset`` and return it unchanged.

    The first figure is whatever ``dataset.scm.draw()`` renders, titled with
    the abbreviated permutation. The second overlays one density-normalised,
    100-bin histogram per column of ``dataset.samples`` (converted to a NumPy
    array), labelled ``x_i`` by column index.
    """
    dataset.scm.draw()
    samples = dataset.samples.to_numpy()
    perm_label = abbriviate_permutation(true_perm)

    # Figure 1: the SCM drawing.
    plt.title(f"SCM ${perm_label}$")
    plt.show()

    # Figure 2: per-column marginal histograms, overlaid on the same axes.
    num_columns = samples.shape[1]
    for column_index in range(num_columns):
        plt.hist(
            samples[:, column_index],
            bins=100,
            density=True,
            alpha=0.5,
            label=f"$x_{column_index}$",
        )
    plt.title(f"True permutation: ${perm_label}$")
    plt.legend()
    plt.show()
    return dataset


def setup_datamodule(data_config):
    """Build the synthetic dataset, plot it, and wrap it in a ``DataModule``.

    ``data_config["dataset_args"]`` is expanded into ``SyntheticOCDDataset``.
    The enforced ordering stored under
    ``scm_generator_args -> graph_generator_args -> enforce_ordering`` is used
    only for the plot titles produced by ``describe_data``. The returned
    ``DataModule`` receives every entry of ``data_config`` as a keyword
    argument, with ``dataset`` set to the freshly built dataset.
    """
    dataset_args = data_config["dataset_args"]
    dataset = SyntheticOCDDataset(**dataset_args)
    true_perm = dataset_args["scm_generator_args"]["graph_generator_args"]["enforce_ordering"]
    describe_data(dataset, true_perm)

    # Shallow copy so the caller's config is not given a "dataset" entry.
    datamodule_kwargs = dict(data_config)
    datamodule_kwargs["dataset"] = dataset
    return DataModule(**datamodule_kwargs)


def setup_logger_tensorboard(n, perm, true_perm, base_name):
    """Create a ``TensorBoardLogger`` for one permutation run.

    Logs go under ``lightning_logs/<base_name>-<n>``. The version is the
    abbreviated permutation, with ``-true`` appended when ``perm`` is the
    true ordering.
    """
    abbreviated = abbriviate_permutation(perm)
    suffix = "" if perm != true_perm else "-true"
    return TensorBoardLogger(
        save_dir="lightning_logs",
        name=f"{base_name}-{n}",
        version=abbreviated + suffix,
    )


def setup_logger_wandb(n, perm, true_perm, base_name):
    """Create a ``WandbLogger`` for one permutation run.

    The run is filed under the ``fixed-simplified`` project and named
    ``<base_name>-<n>/<abbreviated perm>``, with ``-true`` appended when
    ``perm`` is the true ordering.
    """
    abbreviated = abbriviate_permutation(perm)
    suffix = "" if perm != true_perm else "-true"
    run_name = f"{base_name}-{n}/{abbreviated}{suffix}"
    return WandbLogger(project="fixed-simplified", name=run_name)


# Logger factory used by test_fixed_permutations. Rebind this name to
# setup_logger_tensorboard (same signature) to log to TensorBoard instead.
setup_logger = setup_logger_wandb


def test_fixed_permutations(n, data_config, model_config, training_config, seed=666, base_name="fixed"):
    """Train one model per permutation of ``n`` variables on a shared dataset.

    One permutation is drawn at random to be the true ordering and a single
    datamodule is generated with it. A fresh ``TrainingModule`` is then fitted
    for every permutation returned by ``all_permutations(n)``, each with its
    own logger from ``setup_logger``.

    Args:
        n: number of variables; forwarded to the data and model configs.
        data_config: base data config (copied, not mutated, by ``set_data_config``).
        model_config: base model config (copied, not mutated, by ``set_model_config``).
        training_config: keyword arguments expanded into ``Trainer``.
        seed: seed passed to ``seed_everything`` before the true ordering is
            drawn and again before every training run.
        base_name: prefix for the logger run names.
    """
    # Seed before drawing the true ordering and building the dataset, so the
    # first call in a fresh kernel is as reproducible as the later ones.
    seed_everything(seed)
    all_perms = all_permutations(n)
    # Index over the whole list: the previous randint(0, n - 1) could only ever
    # reach the first n entries of all_perms.
    true_perm = all_perms[random.randrange(len(all_perms))]  # pick a random permutation to be the true one
    dm = setup_datamodule(set_data_config(n, data_config, true_perm))
    for i, perm in enumerate(all_perms):
        # Re-seed so every permutation is trained from the same RNG state.
        seed_everything(seed)
        tm = TrainingModule(**set_model_config(n, model_config, ordering=perm))
        print(f"Training with permutation {perm}, true ordering is {true_perm}")
        if i == 0:
            # Print the module once; only the ordering differs between runs.
            print(tm)
        try:
            trainer = Trainer(
                **training_config,
                logger=setup_logger(n, perm, true_perm, base_name),
                callbacks=[LearningRateMonitor(logging_interval="step")],
            )
            trainer.fit(tm, dm)
        finally:
            # Close the wandb run even if training fails, so the next
            # permutation does not log into a stale run.
            wandb.finish()


In [None]:
# test the simplified model with chain graphs
# NOTE: this mutates the shared data_config in place, so any cell run after
# this one also sees graph_type == 'chain'.
data_config["dataset_args"]["scm_generator_args"]["graph_generator_args"]['graph_type'] = 'chain'
# Reload the model config from disk; the `with` block closes the file handle
# once the YAML has been parsed.
with open('config/model-simplified.yaml') as config_file:
    model_config = yaml.load(config_file, Loader=yaml.FullLoader)['init_args']
# Run the full permutation sweep for n = 2, 3 and 4 variables.
for i in range(2, 5):
    test_fixed_permutations(i, data_config, model_config, training_config, base_name='chain')