In [16]:
import numpy as np
from pathlib import Path
import json

# Name of the dataset to process. Exactly one assignment is active; the
# commented-out lines are alternative datasets kept for quick switching.
# DATASET = "sequence_2dcubes_grayscale_shapes_0.05_child_noise_25k_skip_render_with_noise_encodings"
# DATASET = "1dcubes_0.16_child_noise_skip_render"
# DATASET = "1dcubes_equi_child_noise_skip_render"
# DATASET = "1dcubes_0.5_child_noise_skip_render"
# DATASET = "1dcubes_0.16_child_noise_gaussian_skip_render"
# DATASET = "1dcubes_equi_child_noise_gaussian_skip_render"
# DATASET = "1dcubes_0.5_child_noise_gaussian_skip_render"
DATASET = "2dcubes_equiscale_0.05_child_noise"
# DATASET = "2dcubes_equiscale_equi_child_noise"

# Directory the per-split archives of the selected dataset are read from.
DATASET_DIR = Path("outputs/compressed").joinpath(DATASET)

# True selects the randomly initialised rotation encoder further down,
# False the zero-coefficient one.
ROTATE = True

# Results go next to the input directory; the name records whether the
# rotation was applied. The directory is created eagerly.
OUT_DIR = Path("outputs/compressed").joinpath(
    f"{DATASET}_causal_variables{'' if ROTATE else '_unrotated'}"
)
OUT_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): not referenced by any code visible in this notebook.
SHUFFLE = True

# Names of the causal variables, in column order:
# pos_x1, pos_y1, pos_x2, pos_y2, pos_x3, pos_y3
CAUSAL_VARIABLES = [
    f"pos_{axis}{index}" for index in (1, 2, 3) for axis in ("x", "y")
]
# CAUSAL_VARIABLES = [
#     "pos_y1",
#     "pos_y2",
#     "pos_y3",
# ]

In [17]:
# NOTE: the position of a name in this list is the integer value of its
# intervention label; position 0 is the empty (no-op) intervention.
INTERVENTIONS = ["_empty_intervention", *CAUSAL_VARIABLES]

In [18]:
# One tag per ``.npz`` archive in the dataset directory, in sorted path order.
# Only ``.npz`` files are considered because the encoding cell loads
# ``f"{tag}.npz"``; any other file in the directory would otherwise yield a
# tag that cannot be loaded.
tags = [npz_path.stem for npz_path in sorted(DATASET_DIR.glob("*.npz"))]
tags

['dci_train', 'test', 'train', 'val']

In [19]:
import torch
from torch import nn


class IntractableError(Exception):
    """Raised to signal that a requested quantity is fundamentally intractable."""


class Encoder(nn.Module):
    """
    Common interface for encoders and decoders.

    Stores the input and output dimensionality; `forward` and `inverse` are placeholders that
    always raise.
    """

    def __init__(self, input_features=2, output_features=2):
        super().__init__()
        self.output_features = output_features
        self.input_features = input_features

    def forward(self, inputs, deterministic=False):
        """
        Forward transformation (not implemented in this base class).

        For an encoder this maps observed data x to the latent representation z; for a decoder
        it maps the latent representation z to the reconstructed data x.

        Parameters:
        -----------
        inputs : torch.Tensor with shape (batchsize, input_features), dtype torch.float
            Data to be encoded or decoded
        deterministic : bool
            Part of the shared interface; not used by this base class.

        Returns:
        --------
        outputs : torch.Tensor with shape (batchsize, output_features), dtype torch.float
            Encoded or decoded version of the data
        additional_info : None, torch.Tensor, or tuple
            Extra information whose meaning depends on the kind of encoder: the log of the
            Jacobian determinant for flow-style transforms, the log likelihood or log posterior
            for VAE encoders, None otherwise.

        Raises:
        -------
        NotImplementedError
            Always, in this base class.
        """

        raise NotImplementedError()

    def inverse(self, inputs, deterministic=False):
        """
        Inverse transformation, if tractable (this base class always treats it as intractable).

        For a decoder this maps observed data x to the latent representation z; for an encoder
        it maps the latent representation z to the reconstructed data x.

        Parameters:
        -----------
        inputs : torch.Tensor with shape (batchsize, input_features), dtype torch.float
            Data to be encoded or decoded
        deterministic : bool
            Part of the shared interface; not used by this base class.

        Returns:
        --------
        outputs : torch.Tensor with shape (batchsize, output_features), dtype torch.float
            Encoded or decoded version of the data
        additional_info: None or torch.Tensor
            Extra information whose meaning depends on the kind of encoder: the log of the
            Jacobian determinant for flow-style transforms, None otherwise.

        Raises:
        -------
        IntractableError
            Always, in this base class.
        """

        raise IntractableError()


class Inverse(Encoder):
    """
    Wrapper class that inverts the forward and inverse direction, e.g. turning an encoder into a
    decoder.

    The wrapper's `input_features` / `output_features` are the wrapped model's
    `output_features` / `input_features`.
    """

    def __init__(self, base_model):
        super().__init__(
            input_features=base_model.output_features,
            output_features=base_model.input_features,
        )
        self.base_model = base_model

    def forward(self, inputs, deterministic=False):
        """
        Forward transformation: delegates to the wrapped model's `inverse`.

        Parameters:
        -----------
        inputs : torch.Tensor with shape (batchsize, input_features), dtype torch.float
            Data to be encoded or decoded
        deterministic : bool
            Forwarded unchanged to the wrapped model.

        Returns:
        --------
        outputs : torch.Tensor with shape (batchsize, output_features), dtype torch.float
            Encoded or decoded version of the data
        additional_info : None, torch.Tensor, or tuple
            Whatever the wrapped model's `inverse` returns as its second element.
        """
        # Pass `deterministic` through so the wrapper does not silently drop the caller's choice.
        return self.base_model.inverse(inputs, deterministic=deterministic)

    def inverse(self, outputs, deterministic=False):
        """
        Inverse transformation: delegates to the wrapped model's forward pass.

        Parameters:
        -----------
        outputs : torch.Tensor with shape (batchsize, output_features), dtype torch.float
            Data to be mapped back; this is input-side data of the wrapped model.
        deterministic : bool
            Forwarded unchanged to the wrapped model.

        Returns:
        --------
        inputs : torch.Tensor with shape (batchsize, input_features), dtype torch.float
            Encoded or decoded version of the data
        additional_info: None, torch.Tensor, or tuple
            Whatever the wrapped model's forward pass returns as its second element.
        """
        return self.base_model(outputs, deterministic=deterministic)


class SONEncoder(Encoder):
    """
    Deterministic SO(n) encoder / decoder.

    The transformation is a matrix obtained by exponentiating a skew-symmetric matrix whose
    strictly-upper-triangular entries are the learnable `coeffs`; here n is `output_features`.
    """

    def __init__(
        self, coeffs=None, input_features=2, output_features=2, coeff_std=0.05
    ):
        """
        Parameters:
        -----------
        coeffs : None or torch.Tensor with shape (n * (n - 1) // 2,)
            Initial coefficients. If None, they are drawn from a zero-mean normal distribution
            with standard deviation `coeff_std`.
        input_features : int
            Stored on the module; the matrix size is determined by `output_features` only.
        output_features : int
            Dimension n of the matrix.
        coeff_std : float
            Standard deviation used for random initialisation when `coeffs` is None.
        """
        super().__init__(input_features, output_features)

        # Coefficients: one per strictly-upper-triangular entry of an n x n matrix.
        d = self.output_features * (self.output_features - 1) // 2
        if coeffs is None:
            self.coeffs = nn.Parameter(torch.zeros((d,)))  # (d)
            nn.init.normal_(self.coeffs, std=coeff_std)
        else:
            assert coeffs.shape == (d,), (
                f"coeffs must have shape ({d},), got {tuple(coeffs.shape)}"
            )
            self.coeffs = nn.Parameter(coeffs)

        # Generators
        # NOTE(review): all-zero plain tensor, not read anywhere in this class; kept so the
        # attribute remains available to external code.
        self.generators = torch.zeros((d, self.output_features, self.output_features))

    def forward(self, inputs, deterministic=False):
        """Given observed data, returns latent representation; i.e. encoding."""
        z = torch.einsum("ij,bj->bi", self._rotation_matrix(), inputs)
        # Log-determinant reported as zero; created with the inputs' dtype and device so it
        # can be combined with other per-batch quantities without a device mismatch.
        logdet = inputs.new_zeros(())
        return z, logdet

    def inverse(self, inputs, deterministic=False):
        """Given latent representation, returns observed version; i.e. decoding."""
        x = torch.einsum("ij,bj->bi", self._rotation_matrix(inverse=True), inputs)
        logdet = inputs.new_zeros(())
        return x, logdet

    def _rotation_matrix(self, inverse=False):
        """
        Low-level function to generate an element of SO(n) by exponentiating the Lie algebra
        (skew-symmetric matrices)

        With `inverse=True` the sign of the generator is flipped, which yields the matrix used
        for decoding.
        """

        n = self.output_features
        # Match dtype as well as device of the coefficients: index assignment requires source
        # and destination dtypes to agree, so a default-dtype buffer breaks non-float32 coeffs.
        upper = torch.zeros(
            n, n, device=self.coeffs.device, dtype=self.coeffs.dtype
        )
        rows, cols = torch.triu_indices(n, n, offset=1, device=self.coeffs.device)
        upper[rows, cols] = self.coeffs
        # Antisymmetrize explicitly instead of writing through a transposed view.
        generator = upper - upper.T
        if inverse:
            generator = -generator
        return torch.matrix_exp(generator)

In [20]:
class NoopEncoder(SONEncoder):
    """
    SO(n) encoder constructed with all-zero coefficients.

    The coefficients are handed to the parent constructor, which wraps them in an
    `nn.Parameter`.
    """

    def __init__(self, input_features=2, output_features=2):
        num_coeffs = (output_features - 1) * output_features // 2
        super().__init__(
            coeffs=torch.zeros(num_coeffs),
            input_features=input_features,
            output_features=output_features,
        )

In [21]:
def sequence_to_batch(x_sequence):
    """Merge the first two axes of `x_sequence`, keeping all trailing axes: (B, T, ...) -> (B*T, ...)."""
    trailing_dims = tuple(x_sequence.shape[2:])
    flat_shape = (-1,) + trailing_dims
    return x_sequence.reshape(flat_shape)  # (B*T, *x_dim)


def batch_to_sequence(x, batchsize):
    """Reshape `x` to (batchsize, -1, last_dim), where last_dim is the size of its final axis."""
    feature_dim = x.shape[-1]
    sequence_shape = (batchsize, -1, feature_dim)
    return x.reshape(sequence_shape)  # (B, T, z_dim)

In [22]:
# Choose the encoder class from the ROTATE flag: the randomly initialised
# rotation when it is set, the zero-coefficient variant otherwise. Both are
# square, with one dimension per causal variable.
num_encoder_features = len(CAUSAL_VARIABLES)
encoder_class = SONEncoder if ROTATE else NoopEncoder
encoder = encoder_class(
    input_features=num_encoder_features, output_features=num_encoder_features
)
encoder

SONEncoder()

In [23]:
file_format = "pt"  # "npz" or "pt"
if file_format not in ("npz", "pt"):
    # Fail before doing any work instead of silently writing no dataset files.
    raise ValueError(
        f"Unsupported file_format {file_format!r}; expected 'npz' or 'pt'"
    )

# The encoder is the same for every tag, so its parameters are saved once, and
# before any data file is written: the saved datasets are then never left
# without the encoder that produced them if a later tag fails.
torch.save(encoder.state_dict(), OUT_DIR / "encoder.pt")

for tag in tags:
    print(tag)
    # Materialise all arrays and close the archive's file handle right away.
    with np.load(DATASET_DIR / f"{tag}.npz") as archive:
        data = dict(archive)
    sequence_length = data["original_latents"].shape[1]
    causal_images = []
    with torch.no_grad():
        # Map each sequence element's latents through encoder.inverse.
        for i in range(sequence_length):
            print(f"Encoding component {i}...")
            z = torch.tensor(data["original_latents"][:, i, :])
            causal_images_i, _ = encoder.inverse(z)
            causal_images.append(causal_images_i)
    # Re-assemble the sequence axis at position 1.
    causal_images = np.stack(causal_images, axis=1, dtype=np.float32)
    data["imgs"] = causal_images
    if file_format == "npz":
        np.savez_compressed(OUT_DIR / f"{tag}.npz", **data)
    elif file_format == "pt":
        # Tuple layout: index 0 of the sequence axis and the remainder are
        # stored separately for imgs, latents and noise.
        output = (
            torch.tensor(data["imgs"][:, 0, :]),
            torch.tensor(data["imgs"][:, 1:, :]),
            torch.tensor(data["original_latents"][:, 0, :]),
            torch.tensor(data["original_latents"][:, 1:, :]),
            torch.tensor(data["intervention_labels"]),
            torch.tensor(data["intervention_masks"]),
            torch.tensor(data["epsilon"][:, 0, :]),
            torch.tensor(data["epsilon"][:, 1:, :]),
        )
        torch.save(output, OUT_DIR / f"{tag}_encoded.pt")

dci_train
Encoding component 0...
Encoding component 1...
test
Encoding component 0...
Encoding component 1...
train
Encoding component 0...
Encoding component 1...
val
Encoding component 0...
Encoding component 1...


## Sample a fresh dataset from the 2dcubes SCM


In [46]:
# Turn off torch gradient calculation globally from here on. The setting
# persists for the rest of the kernel session, not just this cell.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x142dcdcf0>

In [47]:
def scm(e, WEIGHT1=1, WEIGHT2=0.7, NOISE_SCALE1=6, NOISE_SCALE2=6, NOISE_SCALE3=0.3):
    """
    Map noise variables to the six position variables of the structural causal model.

    Columns 0-1 and 2-3 of `e` are scaled into the (x, y) positions of objects 1 and 2. Each
    coordinate of object 3 is the weighted mean of the matching coordinates of objects 1 and 2
    plus scaled noise from columns 4-5.

    Parameters:
    -----------
    e : torch.Tensor with at least 6 columns along dim 1
        Noise variables, one row per sample.
    WEIGHT1, WEIGHT2 : float
        Weights of object 1 and object 2 in the weighted mean.
    NOISE_SCALE1, NOISE_SCALE2, NOISE_SCALE3 : float
        Scale applied to the noise of object 1, object 2 and object 3 respectively.

    Returns:
    --------
    torch.Tensor with shape (num_samples, 6)
        Columns ordered pos_x1, pos_y1, pos_x2, pos_y2, pos_x3, pos_y3.
    """
    total_weight = WEIGHT1 + WEIGHT2

    def weighted_child(parent1, parent2, noise):
        # Weighted mean of the two parents plus scaled noise.
        return (parent1 * WEIGHT1 + parent2 * WEIGHT2) / total_weight + noise * NOISE_SCALE3

    x1, y1 = e[:, 0] * NOISE_SCALE1, e[:, 1] * NOISE_SCALE1
    x2, y2 = e[:, 2] * NOISE_SCALE2, e[:, 3] * NOISE_SCALE2
    x3 = weighted_child(x1, x2, e[:, 4])
    y3 = weighted_child(y1, y2, e[:, 5])
    return torch.stack([x1, y1, x2, y2, x3, y3], dim=1)


def intervention_mechanism(e, NOISE_SCALE=6):
    """Return the value of an intervened variable: its noise `e` multiplied by `NOISE_SCALE`."""
    intervened_value = e * NOISE_SCALE
    return intervened_value

In [48]:
def uniform(shape, low=-1, high=1):
    """
    Draw a tensor of the given shape with entries uniformly distributed between `low` and `high`.

    The draw `u = torch.rand(shape)` is mapped to `high - (high - low) * u`, so `u = 0` maps
    to `high`.
    """
    width = high - low
    unit_draw = torch.rand(shape)
    return high - width * unit_draw

In [49]:
# Number of intervention labels: the empty intervention plus one per causal
# variable. Derived from INTERVENTIONS so the two cannot drift apart.
num_interventions = len(INTERVENTIONS)

In [51]:
tags = ["train", "val", "test", "dci_train"]
num_causal_variables = len(CAUSAL_VARIABLES)
shuffle = True

# The output location does not depend on the tag, so it is prepared once.
# NOTE: this rebinds the notebook-level OUT_DIR; cells run afterwards that
# read OUT_DIR see this directory.
OUT_DIR = Path("outputs/compressed") / "2d_cubes_generated"
OUT_DIR.mkdir(parents=True, exist_ok=True)

for tag, num_samples_per_intervention in zip(tags, [1500, 150, 150, 150]):
    print(
        f"Generating {num_samples_per_intervention} samples per intervention for {tag}..."
    )
    x1s = []
    x2s = []
    e1s = []
    e2s = []
    z1s = []
    z2s = []
    intervention_labels = []
    intervention_masks = []
    # Label 0 is the empty intervention; label k > 0 intervenes on variable k - 1.
    for intervention in range(num_causal_variables + 1):
        # "Before" sample: noise -> causal variables -> encoder.inverse.
        e1 = uniform((num_samples_per_intervention, num_causal_variables))
        z1 = scm(e1)
        x1 = encoder.inverse(z1)[0]

        e2 = e1.clone().detach()
        # All-zero mask by default, i.e. nothing intervened on.
        intervention_mask_for_intervention = torch.zeros(
            (num_samples_per_intervention, num_causal_variables)
        )
        if intervention == 0:
            # Empty intervention: the "after" sample equals the "before"
            # sample and the mask stays all zero. (Indexing the mask with
            # intervention - 1 == -1 here would wrongly flag the last
            # variable as intervened.)
            z2 = z1.clone().detach()
            x2 = x1.clone().detach()
        else:
            target = intervention - 1
            # sample new noise for intervened variable
            e2[:, target] = uniform((num_samples_per_intervention,))
            z2 = scm(e2)
            # overwrite intervened variable with other mechanism
            z2[:, target] = intervention_mechanism(e2[:, target])
            x2 = encoder.inverse(z2)[0]
            # one only for intervened variable
            intervention_mask_for_intervention[:, target] = 1
        x1s.append(x1)
        x2s.append(x2)
        e1s.append(e1)
        e2s.append(e2)
        z1s.append(z1)
        z2s.append(z2)
        intervention_labels.append(
            torch.ones(num_samples_per_intervention) * intervention
        )
        intervention_masks.append(intervention_mask_for_intervention)

    x1s = torch.cat(x1s, dim=0)
    x2s = torch.cat(x2s, dim=0)
    e1s = torch.cat(e1s, dim=0)
    e2s = torch.cat(e2s, dim=0)
    z1s = torch.cat(z1s, dim=0)
    z2s = torch.cat(z2s, dim=0)
    intervention_labels = torch.cat(intervention_labels, dim=0)
    intervention_masks = torch.cat(intervention_masks, dim=0)

    output = (
        x1s,
        x2s,
        z1s,
        z2s,
        intervention_labels,
        intervention_masks,
        e1s,
        e2s,
    )

    if shuffle:
        # One shared permutation keeps all tensors aligned row by row.
        idx = torch.randperm(x1s.shape[0])
        output = tuple([o[idx] for o in output])

    torch.save(output, OUT_DIR / f"{tag}_encoded.pt")

Generating 1500 samples per intervention for train...
Generating 150 samples per intervention for val...
Generating 150 samples per intervention for test...
Generating 150 samples per intervention for dci_train...


In [52]:
tags = ["train", "val", "test", "dci_train"]
# Fraction of each split assigned to every intervention label
# (label 0 = empty intervention, label k > 0 = variable k - 1).
intervention_ratios = [0.1, 0.1, 0.1, 0.1, 0.1, 0.25, 0.25]
num_causal_variables = len(CAUSAL_VARIABLES)
assert len(intervention_ratios) == num_causal_variables + 1, (
    "need one ratio for the empty intervention plus one per causal variable"
)
shuffle = True

# The output location does not depend on the tag, so it is prepared once.
# NOTE: this rebinds the notebook-level OUT_DIR; cells run afterwards that
# read OUT_DIR see this directory.
OUT_DIR = Path("outputs/compressed") / "2d_cubes_generated_imbalanced"
OUT_DIR.mkdir(parents=True, exist_ok=True)

for tag, total_num_samples in zip(tags, [10000, 1000, 1000, 1000]):
    # total_num_samples is the size of the whole split, not a per-intervention count.
    print(f"Generating {total_num_samples} samples in total for {tag}...")
    x1s = []
    x2s = []
    e1s = []
    e2s = []
    z1s = []
    z2s = []
    intervention_labels = []
    intervention_masks = []
    for intervention, ratio in enumerate(intervention_ratios):
        # round() rather than int(): a product such as 28.999999999999996
        # must not be truncated to one sample fewer than intended.
        num_samples_for_intervention = round(total_num_samples * ratio)
        # "Before" sample: noise -> causal variables -> encoder.inverse.
        e1 = uniform((num_samples_for_intervention, num_causal_variables))
        z1 = scm(e1)
        x1 = encoder.inverse(z1)[0]

        e2 = e1.clone().detach()
        # All-zero mask by default, i.e. nothing intervened on.
        intervention_mask_for_intervention = torch.zeros(
            (num_samples_for_intervention, num_causal_variables)
        )
        if intervention == 0:
            # Empty intervention: the "after" sample equals the "before"
            # sample and the mask stays all zero. (Indexing the mask with
            # intervention - 1 == -1 here would wrongly flag the last
            # variable as intervened.)
            z2 = z1.clone().detach()
            x2 = x1.clone().detach()
        else:
            target = intervention - 1
            # sample new noise for intervened variable
            e2[:, target] = uniform((num_samples_for_intervention,))
            z2 = scm(e2)
            # overwrite intervened variable with other mechanism
            z2[:, target] = intervention_mechanism(e2[:, target])
            x2 = encoder.inverse(z2)[0]
            # one only for intervened variable
            intervention_mask_for_intervention[:, target] = 1
        x1s.append(x1)
        x2s.append(x2)
        e1s.append(e1)
        e2s.append(e2)
        z1s.append(z1)
        z2s.append(z2)
        intervention_labels.append(
            torch.ones(num_samples_for_intervention) * intervention
        )
        intervention_masks.append(intervention_mask_for_intervention)

    x1s = torch.cat(x1s, dim=0)
    x2s = torch.cat(x2s, dim=0)
    e1s = torch.cat(e1s, dim=0)
    e2s = torch.cat(e2s, dim=0)
    z1s = torch.cat(z1s, dim=0)
    z2s = torch.cat(z2s, dim=0)
    intervention_labels = torch.cat(intervention_labels, dim=0)
    intervention_masks = torch.cat(intervention_masks, dim=0)

    output = (
        x1s,
        x2s,
        z1s,
        z2s,
        intervention_labels,
        intervention_masks,
        e1s,
        e2s,
    )

    if shuffle:
        # One shared permutation keeps all tensors aligned row by row.
        idx = torch.randperm(x1s.shape[0])
        output = tuple([o[idx] for o in output])

    torch.save(output, OUT_DIR / f"{tag}_encoded.pt")

Generating 10000 samples per intervention for train...
Generating 1000 samples per intervention for val...
Generating 1000 samples per intervention for test...
Generating 1000 samples per intervention for dci_train...


In [42]:
# Reload the saved train split from whatever directory OUT_DIR currently
# points to, and unpack it in the order in which the tuple was saved.
train_split_path = OUT_DIR / "train_encoded.pt"
(
    x1s,
    x2s,
    z1s,
    z2s,
    intervention_labels,
    intervention_masks,
    e1s,
    e2s,
) = torch.load(train_split_path)