# Try Setting up a New Model

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import numpy as np
from PIL import Image
import numpy as np
import os
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.datasets.utils as dataset_utils

from gendis.datasets import CausalMNIST

import matplotlib.pyplot as plt
import seaborn as sns
import logging
import random
from pathlib import Path

import normflows as nf
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision

from gendis.datasets import CausalMNIST, ClusteredMultiDistrDataModule
from gendis.encoder import NonparametricClusteredCausalEncoder
from gendis.model import NeuralClusteredASCMFlow
from gendis.normalizing_flow.distribution import NonparametricClusteredCausalDistribution


In [3]:
def generate_list(x, n_clusters):
    """Partition ``x`` into ``n_clusters`` integer chunk sizes.

    Every chunk except the last has size ``x // n_clusters``; the last chunk
    additionally absorbs the remainder ``x % n_clusters``. For
    ``n_clusters >= 1`` the returned sizes therefore sum to ``x``.

    Parameters
    ----------
    x : int
        Total amount to split up.
    n_clusters : int
        Number of chunks to produce.

    Returns
    -------
    list of int
        The chunk sizes, remainder-carrying chunk last.
    """
    base, leftover = divmod(x, n_clusters)
    sizes = [base for _ in range(n_clusters - 1)]
    sizes.append(base + leftover)
    return sizes

In [4]:
# --- Experiment configuration -------------------------------------------------
graph_type = "chain"
# 3x3 adjacency matrix with non-zero entries at [0, 1] and [1, 2].
# NOTE(review): presumably row -> column edges (0 -> 1 -> 2); confirm against
# the convention gendis expects.
adjacency_matrix = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
latent_dim = len(adjacency_matrix)  # one latent per row of the adjacency matrix
results_dir = Path("./results/")
results_dir.mkdir(exist_ok=True, parents=True)

# Root directory handed to CausalMNIST. Set the GENDIS_DATA_ROOT environment
# variable to point at another location without editing the notebook; when it
# is unset the original hard-coded path is used, so existing runs are unchanged.
root = os.environ.get("GENDIS_DATA_ROOT", "/Users/adam2392/pytorch_data/")
seed = 1234
max_epochs = 10
accelerator = 'cpu'
batch_size = 10
log_dir = './'

devices = 1
n_jobs = 1
num_workers = 2
print("Running with n_jobs:", n_jobs)

# output filename for the results
fname = results_dir / f"{graph_type}-seed={seed}-results.npz"

# set up logging (root logger at INFO level)
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
# %-style arguments let logging build the message lazily; the emitted text is
# identical to the previous f-string.
logging.info("\n\n\tsaving to %s \n", fname)

# set seed for numpy, the stdlib RNG and everything Lightning seeds
np.random.seed(seed)
random.seed(seed)
pl.seed_everything(seed, workers=True)

# Per-image preprocessing pipeline shared by every dataset built below; the
# steps run in the order listed.
# NOTE(review): a scalar `RandomRotation(350)` samples angles from
# (-350, +350) degrees per torchvision's docs -- confirm this near-full-circle
# range is intended (rather than e.g. 35).
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    nf.utils.Scale(255.0 / 256.0),
    nf.utils.Jitter(1 / 256.0),
    torchvision.transforms.RandomRotation(350),
])

# Build one CausalMNIST dataset per distribution and record its intervention
# targets alongside it. `intervention_idx=None` comes first; the recorded
# output shows all-zero intervention targets for it (presumably the
# observational distribution).
# NOTE(review): the list name `datasets` rebinds the `datasets` module imported
# from torchvision at the top of the notebook.
# NOTE(review): `n_jobs=None` is passed even though `n_jobs` is configured (and
# printed) earlier -- confirm which value is meant to be used.
hard_interventions_per_distr = None
datasets, intervention_targets_per_distr = [], []
for intervention_idx in (None, 1, 2, 3):
    dataset = CausalMNIST(
        root=root,
        train=True,
        download=True,
        label=0,
        graph_type=graph_type,
        intervention_idx=intervention_idx,
        transform=transform,
        n_jobs=None,
    )
    # overwrite=False: do not force regeneration on this call.
    dataset.prepare_dataset(overwrite=False)
    datasets.append(dataset)
    intervention_targets_per_distr.append(dataset.intervention_targets)

# Bundle all distributions into a single Lightning datamodule. `flatten=True`
# is requested here (presumably images are flattened to vectors -- confirm).
data_module = ClusteredMultiDistrDataModule(
    datasets=datasets,
    intervention_targets_per_distr=intervention_targets_per_distr,
    batch_size=batch_size,
    num_workers=num_workers,
    log_dir=log_dir,
    flatten=True,
)
data_module.setup()

INFO:root:

	saving to results/chain-seed=1234-results.npz 

Global seed set to 1234


Running with n_jobs: 1
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-0-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([0.5344]), 'color': tensor([0.3977]), 'fracture_thickness': tensor([8.9378]), 'fracture_num_fractures': tensor([1.]), 'label': 0, 'intervention_targets': [0, 0, 0]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-1-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([4.0281]), 'color': tensor([0.3975]), 'fracture_thickness': tensor([9.4378]), 'fracture_num_fractures': tensor([1.]), 'label': 0, 'intervention_targets': [1, 0, 0]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-2-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([0.4431]), 'color': tensor([0.4416]), 'fracture_thickness': tensor([10.2519]), 'fracture_num_fractures': tensor([0.]), 'label': 0, 'intervention_targets': [0, 0, 1]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-3-train.pt"
torch.Size

In [5]:
# Progress marker: prints a fixed message so it is visible that execution
# reached this point.
print("done")

done


In [6]:
# Optimiser settings (consumed later when the Lightning model is built).
lr = 1e-6
lr_min = 0.0
lr_scheduler = None

# Network hyper-parameters.
n_flows = 3  # number of flows to use in nonlinear ICA model
net_hidden_dim = net_hidden_dim_cbn = 128
net_hidden_layers = net_hidden_layers_cbn = 3
# NOTE(review): the `*_cbn` variants are not referenced by any cell visible in
# this notebook.
fix_mechanisms = False

# Causal graph and the per-node cluster sizes: 784 * 3 (= 28 * 28 * 3) split
# into 3 clusters, i.e. [784, 784, 784].
graph = adjacency_matrix
cluster_sizes = generate_list(x=784 * 3, n_clusters=3)

# 01: Define the causal base distribution with the graph
q0 = NonparametricClusteredCausalDistribution(
    adjacency_matrix=graph,
    cluster_sizes=cluster_sizes,
    n_flows=n_flows,
    n_hidden_dim=net_hidden_dim,
    n_layers=net_hidden_layers,
    fix_mechanisms=fix_mechanisms,
    intervention_targets_per_distr=intervention_targets_per_distr,
    hard_interventions_per_distr=hard_interventions_per_distr,
)

In [8]:
# Leftover debug print marking the start of this cell.
print('wtf')
# 02: Define the flow layers
# Each outer iteration appends, in order: one AffineCouplingBlock, one Permute,
# `num_flow_blocks` GlowBlocks and one Squeeze, so `flows` ends up with
# num_flow_layers * (num_flow_blocks + 3) = 80 entries.
# NOTE(review): the encoder forward pass in a later cell is recorded as failing
# with "mat1 and mat2 shapes cannot be multiplied (84x14 and 1x64)". The
# 1-input MLP and the 2-channel Permute below look like they were written for
# 2-dimensional data rather than (3, 28, 28) images -- confirm the intended
# architecture before training.
num_flow_layers = 10
num_flow_blocks = 5
hidden_channels = 256
channels = 3
split_mode = "channel"
scale = True
flows = []
for i in range(num_flow_layers):
    # Parameter network built from the layer sizes [1, 64, 2]
    # (presumably: 1 input feature, a single hidden layer of 64 units, 2 outputs).
    # Last layer is initialized by zeros making training more stable
    param_map = nf.nets.MLP([1, 64, 2], init_zeros=True)
    # Add flow layer
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions
    # NOTE(review): Permute is configured with 2 while `channels` is 3 -- confirm.
    flows.append(nf.flows.Permute(2, mode="swap"))

    for j in range(num_flow_blocks):
        # First argument evaluates to channels * 2 ** (11 - i); for i = 0 that is
        # 3 * 2 ** 11 = 6144, halving with every outer iteration.
        # NOTE(review): confirm this matches the tensor actually reaching the block.
        flows.append(
            nf.flows.GlowBlock(
                channels * 2 ** (num_flow_layers + 1 - i),
                hidden_channels,
                split_mode=split_mode,
                scale=scale,
            )
        )

    # One Squeeze is appended per outer iteration (10 in total).
    # NOTE(review): verify the 28x28 inputs can be squeezed this many times.
    flows.append(nf.flows.Squeeze())

wtf


LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at /Users/runner/miniforge3/conda-bld/libtorch_1719361051023/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1994.)
  LU, pivots, infos = torch._lu_with_info(


In [10]:
# Progress marker: prints a fixed message so it is visible that execution
# reached this point.
print("done")

done


In [20]:
# 03: Assemble the final normalizing-flow encoder from the base distribution
# `q0`, the flow list and the same graph / cluster / intervention settings
# that `q0` was built with.
encoder = NonparametricClusteredCausalEncoder(
    flows=flows,
    q0=q0,
    graph=graph,
    cluster_sizes=cluster_sizes,
    fix_mechanisms=fix_mechanisms,
    intervention_targets_per_distr=intervention_targets_per_distr,
    hard_interventions_per_distr=hard_interventions_per_distr,
)
# Number of entries yielded by `encoder.parameters()` (presumably parameter
# tensors, not individual scalar weights).
print(sum(1 for _ in encoder.parameters()))

638


In [26]:
# Smoke test: push one image-shaped batch through the encoder and report the
# dtype and the input/output shapes.
# Despite its name, `rand_img` is deterministic: the values 0 .. 28*28*3 - 1
# laid out as a (1, 3, 28, 28) float32 tensor.
# NOTE(review): the recorded run of this cell raised
# "RuntimeError: mat1 and mat2 shapes cannot be multiplied (84x14 and 1x64)",
# so the prints below were never reached. The data module is created with
# flatten=True -- confirm whether the encoder expects flattened input instead
# of a 4-D image batch.
rand_img = torch.arange(28*28*3, dtype=torch.float32).view(1, 3, 28, 28)
out = encoder(rand_img)
print(rand_img.dtype)
print(rand_img.shape, out.shape)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (84x14 and 1x64)

In [None]:

# 04a: Define now the full pytorch lightning model
# Reuse the exact `cluster_sizes` and `graph` objects that `q0` and the encoder
# were built with, instead of recomputing them from magic numbers, so the
# model and its encoder cannot silently disagree if the configuration changes.
model = NeuralClusteredASCMFlow(
    encoder=encoder,
    cluster_sizes=cluster_sizes,
    graph=graph,
    lr=lr,
    lr_scheduler=lr_scheduler,
    lr_min=lr_min,
)

# 04b: Define the trainer for the model
# Checkpoints are written to a directory named after the graph type and seed,
# relative to the current working directory.
checkpoint_root_dir = f"{graph_type}-seed={seed}"
checkpoint_dir = Path(checkpoint_root_dir)
checkpoint_dir.mkdir(exist_ok=True, parents=True)
# NOTE(review): this rebinds `logger`, which earlier held the root
# `logging` logger; from here on it is the (disabled) Lightning logger setting.
logger = None
wandb = False
check_val_every_n_epoch = 1
# Keep the 5 best checkpoints ranked by the monitored "train_loss" metric
# (assumes the model logs a metric with that name -- TODO confirm).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=5,
    monitor="train_loss",
    every_n_epochs=check_val_every_n_epoch,
)

# Train the model
trainer = pl.Trainer(
    max_epochs=max_epochs,
    logger=logger,
    devices=devices,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=check_val_every_n_epoch,
    accelerator=accelerator,
)

# 05: Fit the model and save the data
trainer.fit(
    model,
    datamodule=data_module,
)

# save the final model
# The whole module object is serialized (not just a state_dict), so loading it
# later requires the same model classes to be importable.
torch.save(model, checkpoint_dir / "model.pt")
