# Causal Data Augmentation
> Implementation for *"An Analysis of Causal Effect Estimation using Outcome Invariant Data Augmentation"* (NeurIPS 2025).

<p align="center">
    <img src="https://repository-images.githubusercontent.com/555395031/442c67c5-a047-42cf-96a8-1ad7d8587afa"
    alt="Symmetry as Intervention"
    width="42%">
</p>
<p align="center"><i>Symmetry as Intervention; Causal Estimation with Data Augmentation</i></p>
<p align="center">
  <a href="https://arxiv.org/abs/2510.25128"><img src="https://img.shields.io/badge/arXiv-2510.25128-B31B1B.svg?logo" /></a>
  <a href="https://neurips.cc/virtual/2025/poster/119327"><img src="https://img.shields.io/badge/html-%20neurips.cc-8c5cff.svg"></a>
  <a href="https://uzairakbar.github.io/causal-data-augmentation"><img src="https://img.shields.io/badge/WEB-page-0eb077.svg"></a>
  <a href="https://github.com/uzairakbar/causal-data-augmentation">
  <img src="https://img.shields.io/badge/uzairakbar-causal--data--augmentation-black?logo=github" alt="GitHub repository"></a>
</p>

## Contents
- [Setup](#setup)
- [Linear Experiment](#linear-experiment)
- [Optical Device Experiment](#optical-device-experiment)
- [Colored MNIST Experiment](#colored-mnist-experiment)
- [Citation](#citation)

## Setup

In [None]:
%%capture

# install dependencies and clone repository
!pip install -qq \
    enlighten \
    loguru \
    munch
!sudo apt install -qq \
    cm-super dvipng \
    texlive-latex-extra \
    texlive-latex-recommended
!git clone -b colab https://github.com/uzairakbar/causal-data-augmentation.git

# navigate to the repository directory
%cd causal-data-augmentation

In [None]:
import yaml
from munch import munchify


# Load the experiment config file.
# NOTE: this loads the demo config, which is meant for a quick run.
# Switch to `config.yaml` to reproduce the full results.
with open('demo_config.yaml', 'r') as file:
    config = yaml.safe_load(file)
config = munchify(config)
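
The experiment cells in this notebook test `'<name>' in config` before running and then unpack `**config.<name>` into the experiment's `run` function. The sketch below illustrates why both idioms work on a config that is a dict with attribute access. It deliberately avoids the real `munch` and `yaml` packages so that it runs anywhere: `AttrDict`, `to_attrdict`, `enabled_experiments`, and the example keys and values are hypothetical stand-ins, not part of this repository or of `munch`.

```python
class AttrDict(dict):
    """Hypothetical stand-in for a Munch: a dict that also allows attribute access."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


def to_attrdict(obj):
    """Recursively wrap nested dicts (similar in spirit to munchify)."""
    if isinstance(obj, dict):
        return AttrDict({key: to_attrdict(value) for key, value in obj.items()})
    if isinstance(obj, list):
        return [to_attrdict(value) for value in obj]
    return obj


def enabled_experiments(cfg, names=("linear_simulation", "optical_device", "colored_mnist")):
    """Return the experiment names whose top-level key is present in the config."""
    return [name for name in names if name in cfg]


# Illustrative contents only; the real keys live in demo_config.yaml / config.yaml.
demo = to_attrdict({
    "hyperparameters": {"example_option": 1},
    "linear_simulation": {"n_experiments": 2},
})

# Membership tests see the dict keys, so a missing section simply skips that experiment.
print(enabled_experiments(demo))

# Attribute access returns the nested mapping, which can be unpacked as keyword arguments.
print(dict(**demo.linear_simulation))
```

Removing a top-level section from the config file is therefore enough to skip the corresponding experiment cell without editing any code.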

## Linear Experiment

In [None]:
from loguru import logger

from src.experiments.simulation import (
    linear as linear_simulation,
)


if 'linear_simulation' in config:
    logger.info('Running linear simulation experiment.')
    linear_simulation.run(
        **config.linear_simulation,
        hyperparameters=config.hyperparameters,
    )

In [None]:
# @title confounding strength vs. estimation error
from src.experiments.utils import (
    load,
    sweep_plot,
    ANNOTATE_SWEEP_PLOT,
)


kappa_values = load(
    'artifacts/linear_simulation/kappa_values.pkl'
)
kappa_results = load(
    'artifacts/linear_simulation/kappa_results.pkl'
)

sweep_plot(
    kappa_values, kappa_results, **ANNOTATE_SWEEP_PLOT['kappa']
)

In [None]:
# @title regularization strength vs. estimation error
alpha_values = load(
    'artifacts/linear_simulation/alpha_values.pkl'
)
alpha_results = load(
    'artifacts/linear_simulation/alpha_results.pkl'
)

sweep_plot(
    alpha_values, alpha_results, **ANNOTATE_SWEEP_PLOT['alpha']
)

In [None]:
# @title augmentation strength vs. estimation error
gamma_values = load(
    'artifacts/linear_simulation/gamma_values.pkl'
)
gamma_results = load(
    'artifacts/linear_simulation/gamma_results.pkl'
)

sweep_plot(
    gamma_values, gamma_results, **ANNOTATE_SWEEP_PLOT['gamma']
)

In [None]:
# @title comparison with baselines
from src.experiments.utils import (
    load,
    box_plot,
    ANNOTATE_BOX_PLOT,
)


baseline_results = load(
    'artifacts/linear_simulation/linear_simulation.pkl'
)

box_plot(
    baseline_results,
    fname='linear_simulation',
    experiment='linear_simulation',
    savefig=True, **ANNOTATE_BOX_PLOT['linear_simulation'],
)

**Note:** To reproduce the linear simulation baseline comparison reported in the paper, use the following settings:
```yaml
# config.yaml
kernel_dim:     0
n_experiments:  16
```
```python
# src/sem/simulation/linear.py
COVARIATE_DIMENSION: int = 16
```

## Optical Device Experiment

In [None]:
from loguru import logger

from src.experiments.real import (
    optical_device as optical_device_experiment,
)

if 'optical_device' in config:
    logger.info('Running optical device experiment.')
    optical_device_experiment.run(
        **config.optical_device,
        hyperparameters=config.hyperparameters,
    )

In [None]:
from src.experiments.utils import (
    load,
    box_plot,
    ANNOTATE_BOX_PLOT,
)


baseline_results = load(
    'artifacts/optical_device/optical_device.pkl'
)

box_plot(
    baseline_results,
    fname='optical_device',
    experiment='optical_device',
    savefig=True, **ANNOTATE_BOX_PLOT['optical_device'],
)

## Colored MNIST Experiment

In [None]:
from loguru import logger

from src.experiments.real import (
    cmnist as colored_mnist_experiment,
)

if 'colored_mnist' in config:
    logger.info('Running colored MNIST experiment.')
    colored_mnist_experiment.run(
        **config.colored_mnist,
        hyperparameters=config.hyperparameters,
    )

In [None]:
from src.experiments.utils import (
    load,
    box_plot,
    ANNOTATE_BOX_PLOT,
)


baseline_results = load(
    'artifacts/colored_mnist/colored_mnist.pkl'
)

box_plot(
    baseline_results,
    fname='colored_mnist',
    experiment='colored_mnist',
    savefig=True, **ANNOTATE_BOX_PLOT['colored_mnist'],
)
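
Each plotting cell in this notebook reads a `.pkl` file that only exists once the matching experiment cell has run, and an experiment cell is skipped when its key is absent from the config. A minimal guard is sketched below, assuming the artifacts are ordinary pickle files (suggested by the extension, not confirmed here). The helper `load_if_present` is hypothetical and is not part of `src.experiments.utils`.

```python
import pickle
from pathlib import Path


def load_if_present(path):
    """Return the unpickled object at `path`, or None if the file does not exist."""
    path = Path(path)
    if not path.is_file():
        return None
    with path.open("rb") as handle:
        # Assumption: the artifact was written with the standard pickle module.
        return pickle.load(handle)


# Path taken from the optical device plotting cell in this notebook.
results = load_if_present("artifacts/optical_device/optical_device.pkl")
if results is None:
    print("No optical device artifact found; run the experiment cell first.")
```

If the pickled objects reference classes defined in this repository, unpickling needs the repository root on the import path, as it is after the `%cd` step in the setup cell.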

## Citation
Please cite our paper if you use this code in your work:
```bibtex
@misc{akbar2025causalDataAugmentation,
      title={An Analysis of Causal Effect Estimation using Outcome Invariant Data Augmentation},
      author={Uzair Akbar and Niki Kilbertus and Hao Shen and Krikamol Muandet and Bo Dai},
      year={2025},
      eprint={2510.25128},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.25128},
}
```