# Try Setting up a New Model

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import numpy as np
from PIL import Image
import numpy as np
import os
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.datasets.utils as dataset_utils

from gendis.datasets import CausalMNIST

import matplotlib.pyplot as plt
import seaborn as sns
import logging
import random
from pathlib import Path

import normflows as nf
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision

from gendis.datasets import CausalMNIST, ClusteredMultiDistrDataModule
from gendis.encoder import CausalMultiscaleFlow
from gendis.model import NeuralClusteredASCMFlow
from gendis.normalizing_flow.distribution import NonparametricClusteredCausalDistribution, ClusteredCausalDistribution


In [3]:
def generate_list(x, n_clusters):
    quotient = x // n_clusters
    remainder = x % n_clusters
    result = [quotient] * (n_clusters - 1)
    result.append(quotient + remainder)
    return result

In [4]:
graph_type = "chain"
adjacency_matrix = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
latent_dim = len(adjacency_matrix)
results_dir = Path("./results/")
results_dir.mkdir(exist_ok=True, parents=True)

# root = "/home/adam2392/projects/data/"
root = '/Users/adam2392/pytorch_data/'
# print(args)
# root = args.root_dir
seed = 1234
max_epochs = 10
accelerator = 'cpu'
batch_size = 10
log_dir = './'

devices = 1
n_jobs = 1
num_workers = 2
print("Running with n_jobs:", n_jobs)

# output filename for the results
fname = results_dir / f"{graph_type}-seed={seed}-results.npz"

# set up logging
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
logging.info(f"\n\n\tsaving to {fname} \n")

# set seed
np.random.seed(seed)
random.seed(seed)
pl.seed_everything(seed, workers=True)

# set up transforms for each image to augment the dataset
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        nf.utils.Scale(255.0 / 256.0),  # normalize the pixel values
        nf.utils.Jitter(1 / 256.0),    # apply random generation
        torchvision.transforms.RandomRotation(350),  # get random rotations
    ]
)

# load dataset
datasets = []
intervention_targets_per_distr = []
hard_interventions_per_distr = None
num_distrs = 0
for intervention_idx in [None, 1, 2, 3]:
    dataset = CausalMNIST(
        root=root,
        graph_type=graph_type,
        label=0,
        download=True,
        train=True,
        n_jobs=None,
        intervention_idx=intervention_idx,
        transform=transform,
    )
    dataset.prepare_dataset(overwrite=False)
    datasets.append(dataset)
    num_distrs += 1
    intervention_targets_per_distr.append(dataset.intervention_targets)

# now we can wrap this in a pytorch lightning datamodule
data_module = ClusteredMultiDistrDataModule(
    datasets=datasets,
    num_workers=num_workers,
    batch_size=batch_size,
    intervention_targets_per_distr=intervention_targets_per_distr,
    log_dir=log_dir,
    flatten=False,
)
data_module.setup()

INFO:root:

	saving to results/chain-seed=1234-results.npz 

Global seed set to 1234


Running with n_jobs: 1
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-0-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([0.5344]), 'color': tensor([0.3977]), 'fracture_thickness': tensor([8.9378]), 'fracture_num_fractures': tensor([1.]), 'label': 0, 'intervention_targets': [0, 0, 0]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-1-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([4.0281]), 'color': tensor([0.3975]), 'fracture_thickness': tensor([9.4378]), 'fracture_num_fractures': tensor([1.]), 'label': 0, 'intervention_targets': [1, 0, 0]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-2-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([0.4431]), 'color': tensor([0.4416]), 'fracture_thickness': tensor([10.2519]), 'fracture_num_fractures': tensor([0.]), 'label': 0, 'intervention_targets': [0, 0, 1]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-3-train.pt"
torch.Size

In [5]:
for idx, img in enumerate(data_module.train_dataset):
    
    print(idx, len(img))
    print(img[0].shape)
    break

0 8
torch.Size([3, 28, 28])


In [6]:
print('done')

done


In [7]:
n_flows = 3  # number of flows to use in nonlinear ICA model
lr_scheduler = None
lr_min = 0.0
lr = 1e-6

# Define the model
net_hidden_dim = 128
net_hidden_dim_cbn = 128
net_hidden_layers = 3
net_hidden_layers_cbn = 3
fix_mechanisms = False

graph = adjacency_matrix
cluster_sizes = generate_list(784 * 3, 3)

# 01: Define the causal base distribution with the graph
causalq0 = NonparametricClusteredCausalDistribution(
    adjacency_matrix=graph,
    cluster_sizes=cluster_sizes,
    intervention_targets_per_distr=intervention_targets_per_distr,
    hard_interventions_per_distr=hard_interventions_per_distr,
    fix_mechanisms=fix_mechanisms,
    n_flows=n_flows,
    n_hidden_dim=net_hidden_dim,
    n_layers=net_hidden_layers,
)

causalq0 = ClusteredCausalDistribution(
    adjacency_matrix=graph,
    cluster_sizes=cluster_sizes,
    intervention_targets_per_distr=torch.Tensor(intervention_targets_per_distr),
    hard_interventions_per_distr=hard_interventions_per_distr,
    fix_mechanisms=fix_mechanisms,
)

In [8]:
input_shape = (3, 28, 28)
channels = 3

# Define flows
L = 2
K = 3
n_dims = np.prod(input_shape)
hidden_channels = 256
split_mode = 'channel'
scale = True

stride_factor = 2

# Set up flows, distributions and merge operations
merges = []
flows = []
for i in range(L):
    flows_ = []
    for j in range(K):
        n_chs = channels * 2 ** (L + 1 - i)
        flows_ += [nf.flows.GlowBlock(n_chs, hidden_channels,
                                     split_mode=split_mode, scale=scale)]
    flows_ += [nf.flows.Squeeze()]
    flows += [flows_]
    if i > 0:
        merges += [nf.flows.Merge()]
        latent_shape = (input_shape[0] * stride_factor ** (L - i), input_shape[1] // stride_factor ** (L - i), 
                        input_shape[2] // stride_factor ** (L - i))
    else:
        latent_shape = (input_shape[0] * stride_factor ** (L + 1), input_shape[1] // stride_factor ** L, 
                        input_shape[2] //stride_factor ** L)
    print(n_chs, np.prod(latent_shape), latent_shape)


# 03: Define the final normalizing flow model
# Construct flow model with the multiscale architecture
encoder = CausalMultiscaleFlow(causalq0, flows, merges)

24 1176 (24, 7, 7)
12 1176 (6, 14, 14)


LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at /Users/runner/miniforge3/conda-bld/libtorch_1719361051023/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1994.)
  LU, pivots, infos = torch._lu_with_info(


In [9]:
def print_num_params(model):
    num_params = sum([np.prod(p.shape) for p in model.parameters()])
    print("Number of parameters: {:,}".format(num_params))

print_num_params(encoder)

Number of parameters: 797,024


In [10]:
# run a test to make sure this actually works
rand_img = torch.arange(28*28*3, dtype=torch.float32).view(1, 3, 28, 28)
out = encoder.forward(rand_img)
print(rand_img.dtype)
print(rand_img.shape, out.shape)
print(encoder.inverse_and_log_det(rand_img)[0].shape)

torch.float32
torch.Size([1, 3, 28, 28]) torch.Size([1])
torch.Size([1, 2352])


In [17]:
max_epochs = 2
# 04a: Define now the full pytorch lightning model
model = NeuralClusteredASCMFlow(
    encoder=encoder,
    lr=lr,
    lr_scheduler=lr_scheduler,
    lr_min=lr_min,
)

# 04b: Define the trainer for the model
checkpoint_root_dir = f"{graph_type}-seed={seed}"
checkpoint_dir = Path(checkpoint_root_dir)
checkpoint_dir.mkdir(exist_ok=True, parents=True)
logger = None
wandb = False
check_val_every_n_epoch = 1
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=5,
    monitor="train_loss",
    every_n_epochs=check_val_every_n_epoch,
)

# Train the model
trainer = pl.Trainer(
    max_epochs=max_epochs,
    logger=logger,
    devices=devices,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=check_val_every_n_epoch,
    accelerator=accelerator,
)

# 05: Fit the model and save the data
trainer.fit(
    model,
    datamodule=data_module,
)

# save the final model
torch.save(model, checkpoint_dir / "model.pt")


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type                 | Params
-------------------------------------------------
0 | encoder | CausalMultiscaleFlow | 797 K 
-------------------------------------------------
789 K     Trainable params
7.8 K     Non-trainable params
797 K     Total params
3.188     Total estimated model params size (MB)


torch.Size([23692, 3, 28, 28]) 6 torch.Size([23692]) torch.Size([23692]) torch.Size([23692])
Sanity Checking DataLoader 0:   0%|                                                                   | 0/2 [00:00<?, ?it/s]Hard intervention...  torch.Size([10, 3])
torch.Size([3, 3])
torch.Size([2, 3])
torch.Size([4, 3])
torch.Size([1, 3])
Hard intervention...  torch.Size([10, 3])
Sanity Checking DataLoader 0:  50%|█████████████████████████████▌                             | 1/2 [00:00<00:00, 16.93it/s]Hard intervention...  torch.Size([10, 3])
torch.Size([3, 3])
torch.Size([4, 3])
torch.Size([3, 3])
Hard intervention...  torch.Size([10, 3])
                                                                                                                            

  rank_zero_warn(


Epoch 0:   0%|                                                                                     | 0/2145 [00:00<?, ?it/s]Hard intervention...  torch.Size([10, 3])
torch.Size([3, 3])
torch.Size([1, 3])
torch.Size([4, 3])
torch.Size([2, 3])
Epoch 0:   0%|                                                              | 1/2145 [00:01<52:41,  1.47s/it, loss=3.41e+05]Hard intervention...  torch.Size([10, 3])
torch.Size([1, 3])
torch.Size([3, 3])
torch.Size([4, 3])
torch.Size([2, 3])
Epoch 0:   0%|                                                              | 2/2145 [00:01<29:45,  1.20it/s, loss=3.33e+05]Hard intervention...  torch.Size([10, 3])
torch.Size([3, 3])
torch.Size([2, 3])
torch.Size([3, 3])
torch.Size([2, 3])
Epoch 0:   0%|                                                              | 3/2145 [00:01<22:14,  1.61it/s, loss=3.34e+05]Hard intervention...  torch.Size([10, 3])
torch.Size([2, 3])
torch.Size([3, 3])
torch.Size([1, 3])
torch.Size([4, 3])
Epoch 0:   0%|                  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


                                                                                                                            

AttributeError: Can't pickle local object 'EvaluationLoop.advance.<locals>.batch_to_device'

In [111]:
# save the final model
torch.save(model, checkpoint_dir / "model.pt")

PicklingError: Can't pickle <class 'gendis.model.NeuralClusteredASCMFlow'>: it's not the same object as gendis.model.NeuralClusteredASCMFlow

## Let's sample from the model

In [None]:
[(batch_size, n_chs, width, height), (batch_size, n_chs, width, height), ..., repeat for n_layers]

## RealNVP with Pytorch

In [99]:
import os
import torch
import numpy as np
import pandas as pd

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

In [None]:
input_shape = (3, 28, 28)
channels = 3

# Define flows
L = 2
K = 3
n_dims = np.prod(input_shape)
hidden_channels = 256
split_mode = 'channel'
scale = True

stride_factor = 2

# Set up flows, distributions and merge operations
merges = []
flows = []
for i in range(L):
    flows_ = []
    for j in range(K):
        n_chs = channels * 2 ** (L + 1 - i)
        flows_ += [nf.flows.GlowBlock(n_chs, hidden_channels,
                                     split_mode=split_mode, scale=scale)]
    flows_ += [nf.flows.Squeeze()]
    flows += [flows_]
    if i > 0:
        merges += [nf.flows.Merge()]
        latent_shape = (input_shape[0] * stride_factor ** (L - i), input_shape[1] // stride_factor ** (L - i), 
                        input_shape[2] // stride_factor ** (L - i))
    else:
        latent_shape = (input_shape[0] * stride_factor ** (L + 1), input_shape[1] // stride_factor ** L, 
                        input_shape[2] //stride_factor ** L)
    print(n_chs, np.prod(latent_shape), latent_shape)


# 03: Define the final normalizing flow model
# Construct flow model with the multiscale architecture
encoder = CausalMultiscaleFlow(causalq0, flows, merges)