# Try Setting up a New Model

In [1]:
%load_ext autoreload
%autoreload 2

In [110]:
import os
from pathlib import Path

import numpy as np
from PIL import Image
import numpy as np
import os
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.datasets.utils as dataset_utils

from gendis.datasets import CausalMNIST

import matplotlib.pyplot as plt
import seaborn as sns
import logging
import random
from pathlib import Path

import normflows as nf
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision

from gendis.datasets import CausalMNIST, ClusteredMultiDistrDataModule
from gendis.encoder import NonparametricClusteredCausalEncoder
from gendis.model import NeuralClusteredASCMFlow
from gendis.base import CausalMultiscaleFlow
from gendis.normalizing_flow.distribution import NonparametricClusteredCausalDistribution


In [10]:
def generate_list(x, n_clusters):
    quotient = x // n_clusters
    remainder = x % n_clusters
    result = [quotient] * (n_clusters - 1)
    result.append(quotient + remainder)
    return result

In [94]:
graph_type = "chain"
adjacency_matrix = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
latent_dim = len(adjacency_matrix)
results_dir = Path("./results/")
results_dir.mkdir(exist_ok=True, parents=True)

# root = "/home/adam2392/projects/data/"
root = '/Users/adam2392/pytorch_data/'
# print(args)
# root = args.root_dir
seed = 1234
max_epochs = 10
accelerator = 'cpu'
batch_size = 10
log_dir = './'

devices = 1
n_jobs = 1
num_workers = 2
print("Running with n_jobs:", n_jobs)

# output filename for the results
fname = results_dir / f"{graph_type}-seed={seed}-results.npz"

# set up logging
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
logging.info(f"\n\n\tsaving to {fname} \n")

# set seed
np.random.seed(seed)
random.seed(seed)
pl.seed_everything(seed, workers=True)

# set up transforms for each image to augment the dataset
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        nf.utils.Scale(255.0 / 256.0),  # normalize the pixel values
        nf.utils.Jitter(1 / 256.0),    # apply random generation
        torchvision.transforms.RandomRotation(350),  # get random rotations
    ]
)

# load dataset
datasets = []
intervention_targets_per_distr = []
hard_interventions_per_distr = None
num_distrs = 0
for intervention_idx in [None, 1, 2, 3]:
    dataset = CausalMNIST(
        root=root,
        graph_type=graph_type,
        label=0,
        download=True,
        train=True,
        n_jobs=None,
        intervention_idx=intervention_idx,
        transform=transform,
    )
    dataset.prepare_dataset(overwrite=False)
    datasets.append(dataset)
    num_distrs += 1
    intervention_targets_per_distr.append(dataset.intervention_targets)

# now we can wrap this in a pytorch lightning datamodule
data_module = ClusteredMultiDistrDataModule(
    datasets=datasets,
    num_workers=num_workers,
    batch_size=batch_size,
    intervention_targets_per_distr=intervention_targets_per_distr,
    log_dir=log_dir,
    flatten=False,
)
data_module.setup()

INFO:root:

	saving to results/chain-seed=1234-results.npz 

Global seed set to 1234


Running with n_jobs: 1
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-0-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([0.5344]), 'color': tensor([0.3977]), 'fracture_thickness': tensor([8.9378]), 'fracture_num_fractures': tensor([1.]), 'label': 0, 'intervention_targets': [0, 0, 0]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-1-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([4.0281]), 'color': tensor([0.3975]), 'fracture_thickness': tensor([9.4378]), 'fracture_num_fractures': tensor([1.]), 'label': 0, 'intervention_targets': [1, 0, 0]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-2-train.pt"
torch.Size([3, 28, 28]) {'width': tensor([0.4431]), 'color': tensor([0.4416]), 'fracture_thickness': tensor([10.2519]), 'fracture_num_fractures': tensor([0.]), 'label': 0, 'intervention_targets': [0, 0, 1]}
Loading dataset from "/Users/adam2392/pytorch_data/CausalMNIST/chain/chain-3-train.pt"
torch.Size

In [87]:
for idx, img in enumerate(data_module.train_dataset):
    
    print(idx, len(img))
    print(img[0].shape)
    break

0 8
torch.Size([3, 28, 28])


In [12]:
print('done')

done


In [104]:
n_flows = 3  # number of flows to use in nonlinear ICA model
lr_scheduler = None
lr_min = 0.0
lr = 1e-6

# Define the model
net_hidden_dim = 128
net_hidden_dim_cbn = 128
net_hidden_layers = 3
net_hidden_layers_cbn = 3
fix_mechanisms = False

graph = adjacency_matrix
cluster_sizes = generate_list(784 * 3, 3)

# 01: Define the causal base distribution with the graph
causalq0 = NonparametricClusteredCausalDistribution(
    adjacency_matrix=graph,
    cluster_sizes=cluster_sizes,
    intervention_targets_per_distr=intervention_targets_per_distr,
    hard_interventions_per_distr=hard_interventions_per_distr,
    fix_mechanisms=fix_mechanisms,
    n_flows=n_flows,
    n_hidden_dim=net_hidden_dim,
    n_layers=net_hidden_layers,
)

In [105]:
input_shape = (3, 28, 28)
channels = 3

# Define flows
L = 2
K = 3
n_dims = np.prod(input_shape)
hidden_channels = 256
split_mode = 'channel'
scale = True

stride_factor = 2

# Set up flows, distributions and merge operations
merges = []
flows = []
for i in range(L):
    flows_ = []
    for j in range(K):
        n_chs = channels * 2 ** (L + 1 - i)
        flows_ += [nf.flows.GlowBlock(n_chs, hidden_channels,
                                     split_mode=split_mode, scale=scale)]
    flows_ += [nf.flows.Squeeze()]
    flows += [flows_]
    if i > 0:
        merges += [nf.flows.Merge()]
        latent_shape = (input_shape[0] * stride_factor ** (L - i), input_shape[1] // stride_factor ** (L - i), 
                        input_shape[2] // stride_factor ** (L - i))
    else:
        latent_shape = (input_shape[0] * stride_factor ** (L + 1), input_shape[1] // stride_factor ** L, 
                        input_shape[2] //stride_factor ** L)
    print(n_chs, np.prod(latent_shape), latent_shape)


# 03: Define the final normalizing flow model
# Construct flow model with the multiscale architecture
encoder = CausalMultiscaleFlow(causalq0, flows, merges)

24 1176 (24, 7, 7)
12 1176 (6, 14, 14)


In [106]:
def print_num_params(model):
    num_params = sum([np.prod(p.shape) for p in model.parameters()])
    print("Number of parameters: {:,}".format(num_params))

print_num_params(encoder)

Number of parameters: 22,910,208


In [107]:
# run a test to make sure this actually works
rand_img = torch.arange(28*28*3, dtype=torch.float32).view(1, 3, 28, 28)
out = encoder.forward(rand_img)
print(rand_img.dtype)
print(rand_img.shape, out.shape)
print(encoder.inverse_and_log_det(rand_img)[0].shape)

torch.Size([1, 6, 14, 14])
torch.Size([1, 1176]) 0 1176
torch.float32
torch.Size([1, 3, 28, 28]) torch.Size([1])
torch.Size([1, 2352])


In [108]:
# 04a: Define now the full pytorch lightning model
model = NeuralClusteredASCMFlow(
    encoder=encoder,
    lr=lr,
    lr_scheduler=lr_scheduler,
    lr_min=lr_min,
)

# 04b: Define the trainer for the model
checkpoint_root_dir = f"{graph_type}-seed={seed}"
checkpoint_dir = Path(checkpoint_root_dir)
checkpoint_dir.mkdir(exist_ok=True, parents=True)
logger = None
wandb = False
check_val_every_n_epoch = 1
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=5,
    monitor="train_loss",
    every_n_epochs=check_val_every_n_epoch,
)

# Train the model
trainer = pl.Trainer(
    max_epochs=max_epochs,
    logger=logger,
    devices=devices,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=check_val_every_n_epoch,
    accelerator=accelerator,
)

# 05: Fit the model and save the data
trainer.fit(
    model,
    datamodule=data_module,
)

# save the final model
torch.save(model, checkpoint_dir / "model.pt")


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 0:   0%|                                                                                                       | 0/2145 [09:00<?, ?it/s]
Epoch 0:   0%|                                                                                                       | 0/2145 [06:35<?, ?it/s]
Epoch 0:   1%|█                                                                            | 29/2145 [06:07<7:27:00, 12.68s/it, loss=1.84e+06]
Epoch 0:   0%|                                                                                                       | 0/2145 [01:28<?, ?it/s]



  | Name    | Type                 | Params
-------------------------------------------------
0 | encoder | CausalMultiscaleFlow | 22.9 M
-------------------------------------------------
22.9 M    Trainable params
0         Non-trainable params
22.9 M    Total params
91.641    Total estimated model params size (MB)


torch.Size([23692, 3, 28, 28]) 6 torch.Size([23692]) torch.Size([23692]) torch.Size([23692])
Sanity Checking DataLoader 0:   0%|                                                                                     | 0/2 [00:00<?, ?it/s]torch.Size([10, 6, 14, 14])
torch.Size([10, 1176]) 0 1176
torch.Size([10, 6, 14, 14])
torch.Size([10, 1176]) 0 1176
torch.Size([10]) torch.Size([10, 3, 28, 28]) torch.Size([10]) torch.Size([10]) torch.Size([10]) torch.Size([10]) torch.Size([10])
Sanity Checking DataLoader 0:  50%|██████████████████████████████████████▌                                      | 1/2 [00:00<00:00,  3.06it/s]torch.Size([10, 6, 14, 14])
torch.Size([10, 1176]) 0 1176
torch.Size([10, 6, 14, 14])
torch.Size([10, 1176]) 0 1176
torch.Size([10]) torch.Size([10, 3, 28, 28]) torch.Size([10]) torch.Size([10]) torch.Size([10]) torch.Size([10]) torch.Size([10])
Epoch 0:   0%|                                                                                                       | 0/2145 [00:0

PicklingError: Can't pickle <class 'gendis.model.NeuralClusteredASCMFlow'>: it's not the same object as gendis.model.NeuralClusteredASCMFlow

In [111]:
# save the final model
torch.save(model, checkpoint_dir / "model.pt")

PicklingError: Can't pickle <class 'gendis.model.NeuralClusteredASCMFlow'>: it's not the same object as gendis.model.NeuralClusteredASCMFlow

## Let's sample from the model

In [None]:
[(batch_size, n_chs, width, height), (batch_size, n_chs, width, height), ..., repeat for n_layers]

## RealNVP with Pytorch

In [99]:
import os
import torch
import numpy as np
import pandas as pd

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

In [100]:
device='cpu'

In [101]:
# (Adapted) Code from PyTorch's Resnet impl: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # norm_layer = nn.BatchNorm2d
            norm_layer = nn.InstanceNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

    
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            #norm_layer = nn.BatchNorm2d
            norm_layer = nn.InstanceNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)

        return out

In [102]:
class MyBatchNorm2d(nn.modules.batchnorm._NormBase):
    ''' Partially based on: 
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
        https://discuss.pytorch.org/t/implementing-batchnorm-in-pytorch-problem-with-updating-self-running-mean-and-self-running-var/49314/5 
    '''
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.005,
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype, 'affine': False, 'track_running_stats': True}
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, **factory_kwargs
        )
        
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

    def forward(self, input, validation=False):
        self._check_input_dim(input)
    
        if self.training:
            # Note: Need to detatch `running_{mean,var}` so don't backwards propagate through them
            unbiased_var, tmean = torch.var_mean(input, [0, 2, 3], unbiased=True)
            mean = torch.mean(input, [0, 2, 3]) # along channel axis
            unbiased_var = torch.var(input, [0, 2, 3], unbiased=True) # along channel axis
            running_mean = (1.0 - self.momentum) * self.running_mean.detach() + self.momentum * mean
            
            # Strange: PyTorch impl. of running variance uses biased_variance for the batch calc but
            # *unbiased_var* for the running_var!
            # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L190
            running_var = (1.0 - self.momentum) * self.running_var.detach() + self.momentum * unbiased_var
            
            # BK: Modification from the paper to use running mean/var instead of batch mean/var
            # change shape
            current_mean = running_mean.view([1, self.num_features, 1, 1]).expand_as(input)
            current_var = running_var.view([1, self.num_features, 1, 1]).expand_as(input)
            
            denom = (current_var + self.eps)
            y = (input - current_mean) / denom.sqrt()
            
            self.running_mean = running_mean
            self.running_var = running_var
            
            return y, -0.5 * torch.log(denom)
        else:
            current_mean = self.running_mean.view([1, self.num_features, 1, 1]).expand_as(input)
            current_var = self.running_var.view([1, self.num_features, 1, 1]).expand_as(input)
            
            if validation:
                denom = (current_var + self.eps)
                y = (input - current_mean) / denom.sqrt()
            else:
                # Reverse operation for testing
                denom = (current_var + self.eps)
                y = input * denom.sqrt() + current_mean
                
            return y, -0.5 * torch.log(denom)

In [103]:
class Reshape(nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = tuple([-1] + list(shape))
        
    def forward(self, x):
        return torch.reshape(x, self.shape)

def dense_backbone(shape, network_width):
    input_width = shape[0] * shape[1] * shape[2]
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(input_width, network_width),
        nn.ReLU(),
        nn.Linear(network_width, input_width),
        Reshape(shape)
    )

def bottleneck_backbone(in_planes, planes):
    return nn.Sequential(
        conv3x3(in_planes, planes),
        BasicBlock(planes, planes),
        BasicBlock(planes, planes),
        conv3x3(planes, in_planes),
    )

check_mask = {}
check_mask_device = {}
def checkerboard_mask(shape, to_device=True):
    global check_mask, check_mask_device
    if shape not in check_mask:
        check_mask[shape] = 1 - np.indices(shape).sum(axis=0) % 2
        check_mask[shape] = torch.Tensor(check_mask[shape])
        
    if to_device and shape not in check_mask_device:
        check_mask_device[shape] = check_mask[shape].to(device)
        
    return check_mask_device[shape] if to_device else check_mask[shape]

chan_mask = {}
chan_mask_device = {}
def channel_mask(shape, to_device=True):
    assert len(shape) == 3, shape
    assert shape[0] % 2 == 0, shape
    global chan_mask, chan_mask_device
    if shape not in chan_mask:
        chan_mask[shape] = torch.cat([torch.zeros((shape[0] // 2, shape[1], shape[2])),
                                      torch.ones((shape[0] // 2, shape[1], shape[2])),],
                                      dim=0)
        assert chan_mask[shape].shape == shape, (chan_mask[shape].shape, shape)
        
    if to_device and shape not in chan_mask_device:
        chan_mask_device[shape] = chan_mask[shape].to(device)
        
    return chan_mask_device[shape] if to_device else chan_mask[shape]

In [123]:
class NormalizingFlowMNist(nn.Module):
    EPSILON = 1e-5
    
    def __init__(self, num_coupling=6, num_final_coupling=4, planes=64):
        super(NormalizingFlowMNist, self).__init__()
        self.num_coupling = num_coupling
        self.num_final_coupling = num_final_coupling
        self.shape = (3, 28, 28)
        
        self.planes = planes
        self.s = nn.ModuleList()
        self.t = nn.ModuleList()
        self.norms = nn.ModuleList()
        
        # Learnable scalar scaling parameters for outputs of S and T
        self.s_scale = nn.ParameterList()
        self.t_scale = nn.ParameterList()
        self.t_bias = nn.ParameterList()
        self.shapes = []
      
        shape = self.shape
        for i in range(num_coupling):
            self.s.append(bottleneck_backbone(shape[0], planes))
            self.t.append(bottleneck_backbone(shape[0], planes))
            
            self.s_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_bias.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            
            self.norms.append(MyBatchNorm2d(shape[0]))
            
            self.shapes.append(shape)
           
            if i % 6 == 2:
                shape = (4 * shape[0], shape[1] // 2, shape[2] // 2)
                
            if i % 6 == 5:
                # Factoring out half the channels
                shape = (shape[0] // 2, shape[1], shape[2])
                planes = 2 * planes
       
        # Final coupling layers checkerboard
        for i in range(num_final_coupling):
            self.s.append(bottleneck_backbone(shape[0], planes))
            self.t.append(bottleneck_backbone(shape[0], planes))
            
            self.s_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_bias.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            
            self.norms.append(MyBatchNorm2d(shape[0]))
            
            self.shapes.append(shape)
           
        self.validation = False
    
    def validate(self):
        self.eval()
        self.validation = True
        
    def train(self, mode=True):
        nn.Module.train(self, mode)
        self.validation = False

    def forward(self, x):
        if self.training or self.validation:
            s_vals = []
            norm_vals = []
            y_vals = []
            
            for i in range(self.num_coupling):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape) if i % 6 < 3 else channel_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)
               
                t = (self.t_scale[i]) * self.t[i](mask * x) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * x))
                y = mask * x + (1 - mask) * (x * torch.exp(s) + t)
                s_vals.append(torch.flatten((1 - mask) * s))
               
                if self.norms[i] is not None:
                    y, norm_loss = self.norms[i](y, validation=self.validation)
                    norm_vals.append(norm_loss)
                    
                if i % 6 == 2:
                    y = torch.nn.functional.pixel_unshuffle(y, 2)
                    
                if i % 6 == 5:
                    factor_channels = y.shape[1] // 2
                    y_vals.append(torch.flatten(y[:, factor_channels:, :, :], 1))
                    y = y[:, :factor_channels, :, :]
                    
                x = y
                
            # Final checkboard coupling
            for i in range(self.num_coupling, self.num_coupling + self.num_final_coupling):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)
               
                t = (self.t_scale[i]) * self.t[i](mask * x) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * x))
                y = mask * x + (1 - mask) * (x * torch.exp(s) + t)
                s_vals.append(torch.flatten((1 - mask) * s))
                
                if self.norms[i] is not None:
                    y, norm_loss = self.norms[i](y, validation=self.validation)
                    norm_vals.append(norm_loss)
                
                x = y

            y_vals.append(torch.flatten(y, 1))
            
            # Return outputs and vars needed for determinant
            return (torch.flatten(torch.cat(y_vals, 1), 1),
                    torch.cat(s_vals), 
                    torch.cat([torch.flatten(v) for v in norm_vals]) if len(norm_vals) > 0 else torch.zeros(1),
                    torch.cat([torch.flatten(s) for s in self.s_scale]))
        else:
            y = x
            y_remaining = y
           
            layer_vars = np.prod(self.shapes[-1])
            y = torch.reshape(y_remaining[:, -layer_vars:], (-1,) + self.shapes[-1])
            y_remaining = y_remaining[:, :-layer_vars]
            
            # Reversed final checkboard coupling
            for i in reversed(range(self.num_coupling, self.num_coupling + self.num_final_coupling)):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)
                
                if self.norms[i] is not None:
                    y, _ = self.norms[i](y)
              
                t = (self.t_scale[i]) * self.t[i](mask * y) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * y))
                x = mask * y + (1 - mask) * ((y - t) * torch.exp(-s))
               
                y = x           
          
            layer_vars = np.prod(shape)
            y = torch.cat((y, torch.reshape(y_remaining[:, -layer_vars:], (-1,) + shape)), 1)
            y_remaining = y_remaining[:, :-layer_vars]
            
            # Multi-scale coupling layers
            for i in reversed(range(self.num_coupling)):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape) if i % 6 < 3 else channel_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)
              
                if self.norms[i] is not None:
                    y, _ = self.norms[i](y)
                    
                t = (self.t_scale[i]) * self.t[i](mask * y) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * y))
                x = mask * y + (1 - mask) * ((y - t) * torch.exp(-s))
               
                if i % 6 == 3:
                    x = torch.nn.functional.pixel_shuffle(x, 2)
                    
                y = x
                
                if i > 0 and i % 6 == 0:
                    layer_vars = np.prod(shape)
                    y = torch.cat((y, torch.reshape(y_remaining[:, -layer_vars:], (-1,) + shape)), 1)
                    y_remaining = y_remaining[:, :-layer_vars]
            
            assert np.prod(y_remaining.shape) == 0
            
            return x

In [124]:
def pre_process(x):
    # Convert back to integer values
    x = x * 255.
    
    # Add random uniform [0, 1] noise to get a proper likelihood estimate
    # https://bjlkeng.github.io/posts/a-note-on-using-log-likelihood-for-generative-models/
    x = x + torch.rand(x.shape)

    # Apply transform to deal with boundary effects (see realNVP paper)
    #x = torch.logit(0.05 + 0.90 * x / 256)
    #return x
    return x / 255

def post_process(x):
    # Convert back to integer values
    #return torch.clip(torch.floor(256 / 0.90 * (torch.sigmoid(x) - 0.05)), min=0, max=255) / 255
    return torch.clip(x, min=0, max=1)

In [125]:
model = NormalizingFlowMNist(num_coupling=12, num_final_coupling=4, planes=64).to('cpu')

In [139]:
print(rand_img[0, ...].shape)
model(rand_img)

torch.Size([3, 32, 32])


torch.Size([1])

In [120]:
train_dataset = datasets.MNIST('data', train=True, download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                               ]))

batch_size = 5
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [122]:
for batch, (X, _) in enumerate(train_loader):
    # Transfer to GPU
    X = pre_process(X)
    X = X.to(device)

    # Compute prediction and loss
    print(X.shape)
    y, s, norms, scale = model(X)
    
    break

torch.Size([5, 1, 28, 28])
