# Comparison of CCT- & CMC-learners and other approaches (simulations)

**On Alaa et al. (2023) synthetic data**

In [1]:
import pytensor

# PyTensor settings, applied immediately after importing pytensor and before
# the remaining imports of this cell.
pytensor.config.optimizer = 'fast_compile'  # or 'None' for no optimizations
# Ask PyTensor for verbose error messages (presumably to ease debugging of the
# PyTensor-backed benchmarks -- TODO confirm which ones rely on it).
pytensor.config.exception_verbosity = 'high'

import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
import warnings
import pandas as pd
from sklearn.model_selection import train_test_split
from crepes_weighted import WrapRegressor


from tqdm.notebook import tqdm


sys.path.append('../..')
from src.cmc_metalearners.cmc_metalearners import CMC_S_Learner, CMC_T_Learner, CMC_X_Learner, CCT_Learner
from src.wcp.wcp import NaiveWCP, NestedWCP
from src.conformal_metalearners.CM_learner import CM_learner
from src.datasets.alaa_synthetic import (generate_data)
from src.benchmarks.bart import BART
from src.benchmarks.cmgp import CMGP
from src.benchmarks.cevae import CEVAE
from src.benchmarks.fccn import FCCN
from src.benchmarks.ganite.ganite import GANITE
from src.benchmarks.dklite import DKLITE
from src.benchmarks.diffpo.main_model import DiffPOITE
from src.benchmarks.noflite.noflite import NOFLITE
import torch

## Comparison Setups

In [2]:
# ---------------------------------------------------------------------------
# Experiment-wide settings consumed by the simulation loop in the next cell.
# ---------------------------------------------------------------------------

# True  -> the loop reads the stored train/test CSVs of every simulation.
# False -> the loop generates fresh data and writes those CSVs.
get_save_data = True

# Number of simulation repetitions per setup.
NSim = 100

# Level passed as `alpha=` to every `evaluate(...)` call.
alpha = 0.1

# Base regressor class (instantiated with `learner()` where needed) and the
# short tag that ends up in the result file names.
learner = RandomForestRegressor
learner_name = "RF"

# Only referenced by the (currently commented-out) CMC learners.
MC_samples = 100

# Flag forwarded as `normalized_conformal=`; the matching tag is used in the
# result file names.
normalized_conformal = True
normalized_conformal_name = "normalized" if normalized_conformal else "nonnormalized"

# Selects the "_max_min_y" suffix of the result file names.
max_min_y = True

# Keyword arguments for `generate_data`; the two setups differ only in gamma.
setup_A = dict(n=5000, d=10, gamma=1, alpha=0.1, nexps=1)
setup_B = dict(n=5000, d=10, gamma=0, alpha=0.1, nexps=1)

In [3]:
from pathlib import Path

# Main benchmark loop.
#
# For every setup and every simulation index `n` this cell
#   1. loads the stored train/test CSVs (or generates and stores them when
#      `get_save_data` is False),
#   2. extracts treatment, outcomes, covariates and propensity scores as arrays,
#   3. fits the CCT-learner and each benchmark model on the training data and
#      calls its `evaluate(...)` on the test data,
#   4. writes one CSV per simulation containing one row per approach.
for i, setup in enumerate([setup_A, setup_B]):
    if i == 0:
        # NOTE(review): setup A is skipped here, so only setup B is run.
        # Remove this guard to run both setups.
        continue
    setup_name = "A" if i == 0 else "B"

    # Covariate columns X1..Xd, taken from the setup that is actually being run
    # (previously always read from setup_A, which only works while both setups
    # share the same d).
    feature_cols = [f"X{j}" for j in range(1, setup["d"] + 1)]

    # Suffix of the result file; identical to the two former to_csv branches.
    eval_suffix = "_max_min_y_eval.csv" if max_min_y else "_eval.csv"

    for n in tqdm(range(NSim)):
        path_train = f"../../data/simulations/alaa/setup{setup_name}/simulations_{setup_name}_{n}_train.csv"
        path_test = f"../../data/simulations/alaa/setup{setup_name}/simulations_{setup_name}_{n}_test.csv"
        path_eval = (
            f"../../results/outputs/alaa/setup{setup_name}/eval_dist/"
            f"simulations_{setup_name}_{n}_{learner_name}_{normalized_conformal_name}{eval_suffix}"
        )
        # Create the output directory up front so that a missing folder does not
        # discard the results after all models of this simulation were fitted.
        Path(path_eval).parent.mkdir(parents=True, exist_ok=True)

        if get_save_data:
            ds_train = pd.read_csv(path_train)
            ds_test = pd.read_csv(path_test)
        else:
            ds = generate_data(**setup)
            # 40 % train / 60 % test split of the first element returned by
            # generate_data.
            ds_train, ds_test = train_test_split(ds[0], test_size=0.6, random_state=42)
            Path(path_train).parent.mkdir(parents=True, exist_ok=True)
            ds_train.to_csv(path_train)
            ds_test.to_csv(path_test)

        W_train = ds_train['T'].to_numpy()
        y_train = ds_train['Y'].to_numpy()
        y1_train = ds_train['Y1'].to_numpy()
        y0_train = ds_train['Y0'].to_numpy()
        X_train = ds_train[feature_cols].to_numpy()
        ps_train = ds_train['ps'].to_numpy()
        ite_train = y1_train - y0_train

        W_test = ds_test['T'].to_numpy()
        y_test = ds_test['Y'].to_numpy()
        y1_test = ds_test['Y1'].to_numpy()
        y0_test = ds_test['Y0'].to_numpy()
        X_test = ds_test[feature_cols].to_numpy()
        ps_test = ds_test['ps'].to_numpy()
        ite_test = y1_test - y0_test

        # # Initialize the learner
        # conformal_pseudo_MC_T_Learner = CMC_T_Learner(
        #     learner(),
        #     learner(),
        #     normalized_conformal=normalized_conformal,
        #     pseudo_MC=True,
        #     MC_samples=MC_samples,
        #     max_min_y=max_min_y
        # )
        # conformal_pseudo_MC_T_Learner.fit(X_train, y_train, W_train, ps_train)

        # conformal_MC_T_Learner = CMC_T_Learner(
        #     learner(),
        #     learner(),
        #     normalized_conformal=normalized_conformal,
        #     pseudo_MC=False,
        #     MC_samples=MC_samples,
        #     max_min_y=max_min_y
        # )
        # conformal_MC_T_Learner.fit(X_train, y_train, W_train, ps_train)

        # conformal_pseudo_MC_S_Learner = CMC_S_Learner(
        #     learner(),
        #     normalized_conformal=normalized_conformal,
        #     pseudo_MC=True,
        #     MC_samples=MC_samples,
        #     max_min_y=max_min_y
        # )
        # with warnings.catch_warnings():
        #     conformal_pseudo_MC_S_Learner.fit(X_train, y_train, W_train, ps_train)

        # conformal_MC_S_Learner = CMC_S_Learner(
        #     learner(),
        #     normalized_conformal=normalized_conformal,
        #     pseudo_MC=False,
        #     MC_samples=MC_samples,
        #     max_min_y=max_min_y
        # )

        # with warnings.catch_warnings():
        #     conformal_MC_S_Learner.fit(X_train, y_train, W_train, ps_train)


        # conformal_pseudo_MC_X_Learner = CMC_X_Learner(
        #     learner(),
        #     learner(),
        #     learner(),
        #     normalized_conformal=normalized_conformal,
        #     pseudo_MC=True,
        #     MC_samples=MC_samples,
        #     max_min_y=max_min_y
        # )
        # # Fit the learner
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore")
        #     conformal_pseudo_MC_X_Learner.fit(X_train, y_train, W_train, ps_train)
        # conformal_MC_X_Learner = CMC_X_Learner(
        #     learner(),
        #     learner(),
        #     learner(),
        #     normalized_conformal=normalized_conformal,
        #     pseudo_MC=False,
        #     MC_samples=MC_samples,
        #     max_min_y=max_min_y
        # )

        # # Fit the learner
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore")
        #     conformal_MC_X_Learner.fit(X_train, y_train, W_train, ps_train)

        # Split the training data 50/50 into a fitting ("nuisance") part and a
        # calibration part; the split is seeded with the simulation index.
        (X_train_nuisance, X_train_cal,
            y_train_nuisance, y_train_cal,
            y0_train_nuisance, y0_train_cal,
            y1_train_nuisance, y1_train_cal,
            W_train_nuisance, W_train_cal,
            ite_train_nuisance, ite_train_cal) = train_test_split(
                X_train, y_train, y0_train, y1_train, W_train, ite_train, test_size=0.5, random_state=n
        )

        # Conformal regressors fitted directly on the Y0, Y1 and ITE columns.
        # NOTE(review): these three models are not used again in this cell;
        # they are kept in case later cells rely on them -- confirm and remove
        # if not, since they add three extra fits per simulation.
        conformal_y0 = WrapRegressor(learner())
        conformal_y0.fit(X_train_nuisance, y0_train_nuisance)
        conformal_y0.calibrate(X_train_cal, y0_train_cal, cps=True)
        conformal_y1 = WrapRegressor(learner())
        conformal_y1.fit(X_train_nuisance, y1_train_nuisance)
        conformal_y1.calibrate(X_train_cal, y1_train_cal, cps=True)
        conformal_ite = WrapRegressor(learner())
        conformal_ite.fit(X_train_nuisance, ite_train_nuisance)
        conformal_ite.calibrate(X_train_cal, ite_train_cal, cps=True)

        # One single-row frame per approach is collected here and concatenated
        # once at the end. The leading empty frame fixes the column order
        # ("approach" first), exactly as the former incremental concat did.
        eval_frames = [pd.DataFrame(columns=["approach",
                                             "rmse_y0", "rmse_y1", "rmse_ite",
                                             "coverage_y0", "coverage_y1", "coverage_ite",
                                             "efficiency_y0", "efficiency_y1", "efficiency_ite",
                                             "crps_y0", "crps_y1", "crps_ite",
                                             "ll_y0", "ll_y1", "ll_ite"])]

        # CCT-learner
        print("Fit and evaluate CCT-learner ...")
        conformal_CT_learner = CCT_Learner(learner(), learner(), normalized_conformal=normalized_conformal)
        conformal_CT_learner.fit(X_train, y_train, W_train, p=ps_train)
        evaluate = conformal_CT_learner.evaluate(X_test, y0_test, y1_test, ps_test, alpha=alpha)
        # NOTE(review): the label reads "CTT" although the method is the
        # CCT-learner. It is left untouched because stored result files and
        # downstream analysis may filter on this exact string.
        evaluate["approach"] = "CTT-learner"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # BART
        print("Fit and evaluate BART ...")
        bart = BART()
        bart.fit(X_train, y_train, W_train)
        evaluate = bart.evaluate(X_test, y0_test, y1_test, alpha=alpha)
        evaluate["approach"] = "BART"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # CMGP (receives the training data through its constructor; there is
        # no separate fit call)
        print("Fit and evaluate CMGP ...")
        cmgp = CMGP(X=X_train, Treatments=W_train, Y=y_train)
        evaluate = cmgp.evaluate(X_test, y0_test, y1_test, alpha=alpha)
        evaluate["approach"] = "CMGP"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # CEVAE (all covariates are declared continuous)
        print("Fit and evaluate CEVAE ...")
        dim_bin = 0
        dim_cont = X_train.shape[1]
        cevae = CEVAE(dim_bin=dim_bin, dim_cont=dim_cont)
        cevae.fit(X=X_train, Y=y_train, W=W_train)
        evaluate = cevae.evaluate(X_test, y0_test, y1_test, alpha=alpha)
        evaluate["approach"] = "CEVAE"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # FCCN
        print("Fit and evaluate FCCN ...")
        fccn = FCCN(input_size=X_train.shape[1])
        fccn.train(X_train, y_train, W_train)
        evaluate = fccn.evaluate(X_test, y0_test, y1_test, alpha=alpha)
        evaluate["approach"] = "FCCN"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # GANITE
        print("Fit and evaluate GANITE ...")
        ganite_params = {'h_dim': 30,                          # hidden dimensions
                         'batch_size': 64,                     # the number of samples in each batch
                         'iterations': 10000,                  # the number of iterations for training
                         'alpha': 2.,
                         'beta': 5.,                           # hyper-parameter to adjust the loss importance
                         'input_size': X_train.shape[1],       # the number of features
                         }
        ganite = GANITE(**ganite_params)
        ganite.fit(X_train, y_train, W_train)
        evaluate = ganite.evaluate(X_test, y0_test, y1_test, alpha=alpha)
        evaluate["approach"] = "GANITE"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # DKLITE
        print("Fit and evaluate DKLITE ...")
        dklite = DKLITE(input_dim=X_train.shape[1], output_dim=1)
        dklite.fit(X_train, y_train, W_train)
        evaluate = dklite.evaluate(X_test, y0_test, y1_test, alpha=alpha)
        evaluate["approach"] = "DKLITE"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # diffpo
        print("Fit and evaluate diffpo ...")
        config = {
            "train": {
                "epochs": 500,
                "batch_size": 256,
                "lr": 0.0005,
                "valid_epoch_interval": 50
            },
            "diffusion": {
                "layers": 4,
                "channels": 64,
                "f_dim": 180,
                "cond_dim": X_train.shape[1] + 1, # conditional variable dimension
                "hidden_dim": 128,
                "side_dim": 33,
                "nheads": 2,
                "diffusion_embedding_dim": 128,
                "beta_start": 0.0001,
                "beta_end": 0.5,
                "num_steps": 100,
                "schedule": "quad",
                "mixed": False
            },
            "model": {
                "is_unconditional": 0,
                "timeemb": 32,
                "featureemb": 16,
                "target_strategy": "random",
                "mixed": False
            },
        }
        device = "cuda" if torch.cuda.is_available() else "cpu"
        diffpo = DiffPOITE(config=config,
                           device=device)
        # Unlike the other benchmarks, DiffPOITE.fit is handed both potential
        # outcome columns (Y0 and Y1) together with W and the propensity scores.
        diffpo.fit(X=X_train, Y0=y0_train, Y1=y1_train, W=W_train, ps=ps_train)
        evaluate = diffpo.evaluate(X_test, y0_test, y1_test, W=W_test, ps=ps_test, alpha=alpha)
        evaluate["approach"] = "DiffPO"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # NOFLITE
        print("Fit and evaluate NOFLITE ...")
        params = {
            'input_size': X_train.shape[1],
            'lr': 5e-4,
            'lambda_l1': 1e-3,
            'lambda_l2': 5e-3,
            'batch_size': 128,
            'noise_reg_x': 1e-0,
            'noise_reg_y': 5e-1,
            'hidden_neurons_encoder': 25,
            'hidden_layers_balancer': 3,
            'hidden_layers_encoder': 0,
            'hidden_layers_prior': 2,
            'hidden_neurons_trans': 4,
            'hidden_neurons_cond': 16,
            'hidden_layers_cond': 2,
            'dense': False,
            'n_flows': 1,
            'datapoint_num': 8,
            'resid_layers': 1,
            'max_steps': 10000,
            'flow_type': "SigmoidX",
            'metalearner': "T",
            'lambda_mmd': 0.1,
            'n_samples': 500,
            'trunc_prob': 0.01,
            'bin_outcome': False,
            'iterations': 1,
            'visualize': False,
        }
        noflite = NOFLITE(params=params)
        noflite.fit(X_train, y_train, W_train)
        evaluate = noflite.evaluate(X_test, y0_test, y1_test, W=W_test, alpha=alpha)
        evaluate["approach"] = "NOFLITE"
        eval_frames.append(pd.DataFrame(evaluate, index=[0]))

        # Assemble the per-approach rows once and persist them for this simulation.
        df_eval = pd.concat(eval_frames, ignore_index=True)
        df_eval.to_csv(path_eval, index=False)

  0%|          | 0/100 [00:00<?, ?it/s]

Fit and evaluate CCT-learner ...
Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 18 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:36<00:00, 191.40it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:13<00:00, 149.95it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:21<00:00, 19.96it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [08:50<00:00, 18.84it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:48, 19.50it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 16.56it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 16.94it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 16.96it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 17.14it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 16.98it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 15.45it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 17.37it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 16.58it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 16.27it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 18.26it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 18.50it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 17.60it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 17.46it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 17.73it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [01:06<00:00, 33.32s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 17.01it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 17.99it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 16.69it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 16.58it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 16.90it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 17.13it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 17.07it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 13.87it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 16.38it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 15.67it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 15.93it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 15.97it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 17.57it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 16.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:07<00:00, 33.75s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 16.74it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 16.91it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 16.91it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 16.75it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 17.05it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 16.77it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 16.86it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 16.94it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 13.82it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 17.18it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 17.55it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 17.67it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:08<00:00, 34.36s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 17.35it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 17.11it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 17.14it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 16.17it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 17.28it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 14.73it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 17.46it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 17.42it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 17.53it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 16.25it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.26it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:07<00:00, 33.72s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 17.04it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 17.00it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 17.00it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 17.08it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 17.13it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 16.85it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 17.11it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 17.05it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 14.39it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 17.51it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 17.25it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 17.50it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:07<00:00, 33.60s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 17.58it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 15.16it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 13.51it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 15.99it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 13.94it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 16.60it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 17.29it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 13.80it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 16.60it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 16.63it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 14.01it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 14.96it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:12<00:00, 36.43s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 16.98it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 17.32it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 17.25it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 16.64it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 16.58it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 16.77it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 17.65it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 18.11it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 17.94it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 18.30it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 17.27it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 17.43it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [01:11<00:00, 35.86s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 13.36it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 17.85it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 17.46it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 17.23it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 13.22it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 13.51it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 14.34it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 17.80it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 18.16it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 15.04it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 15.27it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [01:04<00:00, 32.49s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 16.74it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 16.94it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 17.24it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 17.03it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 16.94it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 17.28it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 17.20it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 16.89it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 17.89it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 18.16it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [01:11<00:00, 35.62s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 14.36it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 14.30it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 12.47it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 15.03it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 15.57it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 15.02it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 12.68it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 15.02it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 15.40it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 16.62it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 16.70it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 15.58it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [01:05<00:00, 32.60s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [14:20<00:00, 71.69s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 20 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:37<00:00, 185.58it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:18<00:00, 144.56it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:19<00:00, 20.01it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:31<00:00, 17.48it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:49, 19.35it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 17.48it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 16.43it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 15.35it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 17.05it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 16.82it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 14.59it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 16.11it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 15.65it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 14.95it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 18.19it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 17.80it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 18.11it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [00:58<00:00, 29.49s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 17.86it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 18.02it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 17.91it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 17.84it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 17.59it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 17.93it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 17.94it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 18.16it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 13.98it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 18.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [00:53<00:00, 26.99s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 16.26it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 17.79it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 18.38it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 18.46it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 17.78it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 18.43it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 18.24it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 18.12it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 17.61it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [00:54<00:00, 27.38s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 14.88it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 15.38it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 16.75it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 14.55it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 13.75it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 13.87it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 18.02it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 18.46it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 18.69it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.88it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [00:54<00:00, 27.27s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 18.21it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 18.00it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 18.30it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 18.51it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 15.13it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 13.06it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 17.51it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 16.48it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 17.85it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 13.61it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [00:53<00:00, 26.74s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 18.37it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 13.37it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 17.37it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 18.64it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 14.51it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 18.02it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 18.13it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 12.00it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 18.14it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 18.57it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 18.35it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [00:54<00:00, 27.26s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 15.92it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 17.35it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 15.28it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 15.91it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 13.90it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 18.46it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 14.45it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 15.76it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 16.34it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 18.47it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 15.14it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [00:54<00:00, 27.32s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 18.58it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 18.26it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 17.83it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 17.77it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 18.25it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 18.37it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 17.59it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 12.80it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 15.82it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 17.34it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [00:54<00:00, 27.37s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 14.06it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 18.23it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 18.21it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 15.73it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 18.17it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 18.00it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 18.72it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [00:53<00:00, 26.66s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 18.17it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 14.24it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 18.67it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 17.54it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 18.20it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 17.01it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 17.99it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 18.12it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 14.29it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 16.15it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 14.33it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [00:54<00:00, 27.21s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [12:36<00:00, 63.05s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 18 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:35<00:00, 198.25it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [01:59<00:00, 167.68it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [07:57<00:00, 20.95it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [08:53<00:00, 18.75it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:48, 19.67it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 18.82it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 20.57it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 20.32it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 20.06it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 19.07it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 19.56it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 19.23it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 19.29it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 19.03it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 19.37it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 18.96it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 18.79it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 16.11it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 14.10it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [00:49<00:00, 24.89s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 18.64it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 18.62it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 18.85it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 19.32it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 19.32it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 18.65it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 13.95it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 17.93it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 18.37it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 19.01it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 18.71it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 18.94it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 19.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [00:50<00:00, 25.21s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 19.43it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 19.21it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 19.28it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 14.92it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 17.17it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 17.64it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 18.68it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 18.31it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 17.48it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 17.88it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 18.27it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 18.73it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 19.34it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [00:49<00:00, 24.98s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 19.43it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 19.49it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 18.88it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 19.25it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 19.24it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 19.23it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 19.04it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 19.21it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 18.57it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 15.63it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 17.50it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 17.69it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.80it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [00:51<00:00, 25.57s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 15.35it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 14.65it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 17.40it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 18.77it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 14.62it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 19.02it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 18.76it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 19.42it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 19.37it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 19.28it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 19.18it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 15.72it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 19.04it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [00:50<00:00, 25.13s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 16.42it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 18.96it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 19.34it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 19.06it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 18.79it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 18.76it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 18.75it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 19.04it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 19.48it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 18.68it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 18.79it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 19.08it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 19.38it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [00:49<00:00, 24.88s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 18.90it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 15.41it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 18.03it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 18.12it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 17.94it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 17.31it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 19.08it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 19.29it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 18.85it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 19.07it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 17.56it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [00:50<00:00, 25.30s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 14.47it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 15.27it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 14.85it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 14.35it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 19.38it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 18.79it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 18.03it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 19.03it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 19.14it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 19.36it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 19.17it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 14.42it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 18.86it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [00:51<00:00, 25.92s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 18.73it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 16.14it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 18.78it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 18.35it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 18.53it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 18.68it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 16.96it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 16.81it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 18.04it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 18.89it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 18.68it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 18.89it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 18.69it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [00:51<00:00, 25.67s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 14.40it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 17.42it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 17.50it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 17.50it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 17.27it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 18.32it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 17.82it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 16.24it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.33it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 18.78it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [00:52<00:00, 26.12s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [11:48<00:00, 59.07s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 18 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:32<00:00, 217.20it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:05<00:00, 159.24it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:40<00:00, 19.22it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:36<00:00, 17.36it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:50, 18.98it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 15.21it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 15.97it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 17.33it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 17.35it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 18.62it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 17.73it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 18.45it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 17.69it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 17.12it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 16.36it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 16.62it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 16.07it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 15.34it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [00:56<00:00, 28.13s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 14.81it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 18.04it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 18.53it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 16.06it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 18.54it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 17.36it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 18.39it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 18.76it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 18.78it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 18.75it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 17.27it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 17.20it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 16.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [00:55<00:00, 27.82s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 18.38it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 18.66it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 19.00it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 20.12it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 18.57it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 17.58it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 17.27it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 17.95it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 18.06it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 18.06it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 15.22it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 16.38it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [00:55<00:00, 27.97s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 17.93it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 17.82it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 16.25it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 16.15it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 16.29it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 15.99it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 16.07it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 18.83it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 17.27it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 18.14it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 17.60it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 17.74it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [00:55<00:00, 27.67s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 19.34it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 18.23it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 13.75it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 16.10it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 15.46it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 18.66it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 17.75it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 16.92it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 16.06it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 19.00it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 18.96it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [00:55<00:00, 27.89s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 15.13it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 17.61it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 17.96it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 18.05it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 15.31it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 17.53it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 18.45it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 19.11it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 19.00it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 20.02it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 18.69it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [00:57<00:00, 28.61s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 18.24it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 18.59it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 18.82it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 18.63it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 20.22it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 16.77it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 14.04it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 18.73it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 18.29it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 17.66it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 17.65it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [00:56<00:00, 28.14s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 17.66it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 16.34it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 16.64it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 18.86it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 18.95it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 19.14it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 18.49it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 17.18it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 17.24it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 16.55it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 17.12it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 17.14it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 17.33it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [00:55<00:00, 27.64s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 17.78it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 16.12it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 16.94it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 18.59it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 18.71it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 19.00it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 18.48it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 18.81it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 18.70it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 17.01it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 16.31it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 16.00it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [00:53<00:00, 26.71s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 18.83it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 18.50it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 19.00it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 17.94it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 17.45it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 17.56it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 14.76it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 19.27it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 18.60it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 18.84it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 19.03it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.71it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 17.62it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [00:53<00:00, 26.86s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [12:02<00:00, 60.22s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 18 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:31<00:00, 219.11it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:25<00:00, 137.29it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [07:46<00:00, 21.42it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [08:55<00:00, 18.67it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:50, 18.81it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 13.51it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 14.62it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 16.40it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 17.74it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 17.92it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 17.42it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 17.26it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 11.29it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 18.06it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 17.89it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 18.35it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 15.85it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 17.14it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [01:02<00:00, 31.44s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 14.89it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 17.61it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 17.86it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 18.23it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 18.31it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 14.66it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 17.15it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 16.29it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 12.92it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 17.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:09<00:00, 34.97s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 18.38it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 16.07it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 16.30it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 15.33it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 12.27it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 17.73it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 15.37it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 16.46it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 16.47it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 19.02it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 12.96it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:03<00:00, 31.62s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 17.79it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 17.88it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 18.24it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 15.89it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 18.19it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 17.78it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 18.35it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 18.47it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 18.48it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 18.66it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 18.58it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:02<00:00, 31.00s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 14.09it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 17.88it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 18.25it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 18.38it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 18.64it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 17.60it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 18.44it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 18.08it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 18.79it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:02<00:00, 31.03s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 14.47it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 15.78it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 18.29it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 18.51it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 18.43it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 18.53it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 18.55it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 16.24it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 18.00it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 18.14it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 18.11it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 18.14it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:02<00:00, 31.03s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 18.29it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 12.83it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 16.78it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 18.86it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 18.92it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 18.72it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 18.81it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 18.16it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 17.62it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 17.78it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [01:02<00:00, 31.03s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 18.04it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 18.26it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 18.39it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 18.08it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 18.03it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 18.27it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 18.51it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 14.13it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 17.20it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [00:59<00:00, 29.78s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 18.49it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 16.43it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 18.33it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 14.39it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 14.59it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 16.92it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 16.44it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 18.12it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 17.04it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [01:02<00:00, 31.01s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 17.76it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 18.26it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 17.62it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 18.00it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 18.08it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 17.76it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 17.91it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 17.57it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 17.74it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 17.42it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 18.13it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.31it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [01:01<00:00, 30.56s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [14:07<00:00, 70.65s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 19 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:32<00:00, 215.58it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:33<00:00, 130.70it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:16<00:00, 20.15it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [08:49<00:00, 18.88it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:49, 19.30it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 11.59it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 13.59it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 14.56it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 16.17it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 16.31it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 13.50it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 19.09it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 19.40it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 17.74it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 17.56it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 18.19it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 17.37it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 17.86it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [01:08<00:00, 34.11s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 18.08it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 14.36it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 18.71it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 18.94it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 18.75it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 18.58it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 19.03it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 12.84it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 16.73it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 16.34it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 14.66it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 15.40it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 18.21it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 18.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:07<00:00, 33.59s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 17.93it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 17.14it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 17.73it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 16.68it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 16.60it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 17.68it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 17.54it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 18.01it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 18.19it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 17.95it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 16.15it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 15.89it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:19<00:00, 39.72s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 17.11it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 17.09it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 14.05it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 12.67it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 17.18it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 18.63it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 18.49it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 18.64it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 17.72it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 17.59it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 17.24it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.24it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:21<00:00, 40.52s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 17.17it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 13.33it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 17.75it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 17.20it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 17.22it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 16.93it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 17.15it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 16.59it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 16.84it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 16.96it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 17.21it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:14<00:00, 37.21s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 17.57it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 13.57it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 16.33it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 17.08it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 17.13it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 16.80it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 16.90it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 17.08it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 15.46it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 17.05it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 16.87it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 13.06it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:21<00:00, 40.69s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 16.85it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 16.73it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 14.56it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 17.76it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 17.23it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 17.54it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 12.93it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 11.40it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 16.28it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 17.35it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 17.45it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 13.76it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [01:21<00:00, 40.87s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 13.79it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 16.80it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 13.38it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 17.48it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 16.93it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 17.50it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 16.71it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 16.85it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 13.47it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 14.40it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 16.84it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 13.59it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [01:20<00:00, 40.24s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 16.98it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 17.53it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 16.71it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 16.98it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 12.48it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 16.90it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 16.58it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 14.39it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 16.54it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 15.95it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 16.82it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [01:21<00:00, 40.75s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 16.81it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 16.99it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 16.85it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 11.98it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 15.96it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 16.94it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 16.80it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 16.73it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 16.83it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 13.53it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 15.58it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 16.51it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [01:20<00:00, 40.47s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [17:40<00:00, 88.40s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 18 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:35<00:00, 195.60it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:29<00:00, 134.09it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:46<00:00, 19.00it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:46<00:00, 17.05it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:55, 17.15it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 16.23it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 16.53it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 14.47it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 13.93it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 16.13it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 14.77it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 17.07it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 17.31it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 16.25it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 14.94it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 15.01it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 15.71it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 16.19it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [01:29<00:00, 44.80s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 12.36it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 11.16it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 12.45it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 15.71it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 14.25it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 15.12it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 13.86it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 13.97it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 15.42it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 15.79it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 16.49it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 16.70it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 16.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:23<00:00, 41.54s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 15.16it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 14.67it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 15.88it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 13.62it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 16.35it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 16.19it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 16.63it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 16.45it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 17.03it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 16.55it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 13.03it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 13.54it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:17<00:00, 38.51s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 15.12it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 15.46it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 15.28it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 15.13it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 15.59it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 16.13it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 15.97it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 15.93it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 16.24it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 15.57it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 16.04it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 16.36it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.07it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:16<00:00, 38.11s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 17.22it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 20.21it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 19.40it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 18.81it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 18.72it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 18.24it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 14.31it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 16.08it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 16.08it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 15.65it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 14.40it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 13.11it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:22<00:00, 41.31s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 14.65it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 15.62it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 15.11it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 13.71it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 11.88it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 15.36it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 14.38it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 14.59it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 16.27it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 15.60it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 13.45it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 10.87it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 16.06it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:13<00:00, 36.72s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 15.93it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 16.57it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 16.35it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 18.97it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 17.54it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 16.22it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 17.79it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 16.85it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 16.06it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 15.13it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [01:15<00:00, 37.67s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 16.50it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 16.71it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 15.88it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 16.35it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 16.89it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 16.63it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 14.47it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 15.18it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 15.42it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 15.68it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 14.66it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 17.07it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 16.72it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [01:15<00:00, 37.97s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 16.09it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 15.24it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 17.77it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 17.84it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 18.02it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 17.18it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 15.61it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 15.48it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 13.93it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 14.94it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 17.26it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 16.84it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 16.69it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [01:12<00:00, 36.47s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 15.85it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 15.58it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 16.74it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 16.40it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 13.89it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 17.45it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 11.90it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 18.32it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 16.90it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 16.68it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 16.77it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 16.61it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 15.79it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [01:20<00:00, 40.11s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [14:00<00:00, 70.07s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 19 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details


X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:35<00:00, 197.83it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:16<00:00, 146.77it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:35<00:00, 19.41it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:07<00:00, 18.28it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:49, 19.02it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 14.33it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 16.29it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 16.10it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 13.40it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 16.28it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 16.47it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 15.71it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 16.26it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 15.97it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 16.48it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 15.49it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 14.88it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 16.30it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 15.93it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [01:23<00:00, 41.84s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 16.35it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 16.16it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 15.94it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 16.15it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 16.12it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 16.38it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 14.84it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 16.37it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 16.20it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 16.22it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 12.89it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 14.93it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 15.99it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 16.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:26<00:00, 43.04s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 14.22it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 16.21it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 16.18it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 15.60it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 16.69it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 16.19it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 16.18it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 16.16it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 16.27it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 14.27it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 14.53it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 15.93it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 16.24it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:23<00:00, 41.63s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 16.37it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 16.38it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 16.06it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 16.65it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 15.51it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 13.99it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 15.30it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 16.31it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 16.45it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 15.77it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 14.48it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 16.13it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.03it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:20<00:00, 40.06s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 13.69it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 13.71it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 15.24it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 13.25it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 10.84it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00,  7.97it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 16.23it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 16.24it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 15.77it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 16.42it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 14.48it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 14.68it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 16.23it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:12<00:00, 36.02s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 16.00it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 14.07it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 13.17it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 10.86it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 10.48it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 12.16it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 13.58it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 17.01it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 15.90it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 16.10it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 14.29it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 14.51it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 15.31it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:11<00:00, 35.55s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 16.13it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 16.72it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 13.80it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 15.27it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 16.26it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 15.89it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 15.98it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 16.08it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 16.39it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 16.54it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 17.47it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 17.20it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 16.38it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [01:16<00:00, 38.23s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 16.86it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 16.79it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 18.39it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 16.11it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 17.12it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 16.44it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 16.72it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 17.29it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 15.37it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 12.72it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 15.81it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 16.40it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 16.63it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [01:16<00:00, 38.13s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 13.22it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 16.12it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 16.14it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 15.91it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 16.13it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 11.61it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 14.94it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 15.89it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 17.80it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 13.82it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 16.04it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 15.80it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 16.09it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [01:11<00:00, 35.71s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 14.84it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 17.61it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 15.34it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 17.41it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 18.77it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 18.81it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 18.60it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 19.02it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 18.88it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 18.67it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.38it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 18.59it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [01:12<00:00, 36.46s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [13:52<00:00, 69.36s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 18 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:31<00:00, 224.34it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [01:56<00:00, 171.17it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:18<00:00, 20.08it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:20<00:00, 17.84it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:50, 18.75it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 15.78it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 17.28it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 17.45it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 15.57it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 15.26it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 16.17it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 17.75it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 14.51it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 16.98it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 16.45it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 15.87it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 17.18it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 17.74it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [00:59<00:00, 29.57s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 17.44it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 17.51it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 16.36it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 16.80it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 17.47it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 17.07it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 16.72it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 17.08it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 17.54it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 17.11it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 16.46it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 15.01it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 18.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [00:58<00:00, 29.01s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 16.96it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 17.60it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 17.95it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 16.21it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 17.44it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 18.35it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 18.90it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 16.69it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 18.23it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 14.95it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 17.54it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [02:19<00:00, 69.52s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 15.82it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 16.93it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 18.21it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 18.19it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 17.28it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:01<00:00,  5.33it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 13.37it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 17.96it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 19.01it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 19.29it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 18.52it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 17.22it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 17.59it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [20:55<00:00, 627.84s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 17.44it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 18.44it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 16.34it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 19.14it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 18.38it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 16.78it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 14.98it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 16.44it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 18.48it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 18.58it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 18.94it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [00:57<00:00, 28.93s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 17.91it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 17.70it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 15.03it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 17.28it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 15.62it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 19.03it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 16.83it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 16.14it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 17.20it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 16.86it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 18.17it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:00<00:00, 30.16s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 15.86it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 16.29it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 16.28it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 15.28it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 13.09it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 14.33it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 17.91it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 17.21it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 18.02it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 14.42it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 18.31it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 15.54it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 12.31it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [00:59<00:00, 29.73s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 14.89it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 17.00it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 17.03it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 18.17it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 14.74it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 18.11it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 18.27it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 18.12it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 18.79it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 19.18it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 19.01it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 18.86it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [00:59<00:00, 29.85s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 14.94it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 18.42it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 14.54it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 17.42it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 17.61it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 14.81it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 14.14it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 14.24it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 15.44it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 11.86it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 11.37it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 16.12it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 16.23it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [00:57<00:00, 28.91s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 16.82it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 16.90it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 19.08it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 18.81it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 18.05it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 16.42it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 18.62it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 18.43it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 18.34it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.84it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 18.70it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [00:58<00:00, 29.35s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [12:48<00:00, 64.07s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 19 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:32<00:00, 212.30it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:14<00:00, 148.42it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [07:58<00:00, 20.91it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [1:44:02<00:00,  1.60it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:51, 18.50it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 13.08it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 16.88it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 17.16it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 17.94it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 17.29it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 15.81it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 16.18it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 17.77it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 16.56it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 17.52it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 17.56it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 17.64it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [01:07<00:00, 33.51s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 16.09it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 15.32it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 16.27it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 16.88it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 15.30it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 17.49it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 17.17it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 15.75it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 17.43it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 17.31it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 16.63it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 17.38it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 16.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:04<00:00, 32.37s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 17.70it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 17.84it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 17.29it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 18.22it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 17.68it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 14.97it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00, 16.69it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 15.41it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 16.73it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 17.09it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 17.62it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:05<00:00, 32.61s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 14.94it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 16.37it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 16.70it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 15.25it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 14.86it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 17.51it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 17.48it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 17.07it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 16.93it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 16.92it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 17.30it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 16.41it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 15.89it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:04<00:00, 32.25s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 17.82it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00, 17.78it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:00<00:00, 18.04it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00, 12.99it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 14.43it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 18.26it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 18.17it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 17.64it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 18.04it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:02<00:00, 31.11s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 16.97it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 14.75it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 17.26it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 17.83it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 17.63it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 17.57it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 15.86it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 15.77it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 15.43it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 16.30it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 15.55it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 18.32it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [00:58<00:00, 29.18s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 18.14it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 18.08it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 15.26it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 18.56it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 17.93it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 18.30it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 18.33it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 18.17it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 18.05it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 18.30it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 18.37it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [00:59<00:00, 29.69s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 16.75it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 18.76it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 14.45it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 18.48it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 15.88it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 12.75it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 17.58it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 17.92it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 17.68it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [01:01<00:00, 30.93s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 18.06it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 18.00it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 14.03it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 15.00it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 17.60it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 17.53it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 18.14it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 16.28it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 16.55it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 16.66it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 17.38it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [01:01<00:00, 30.66s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 14.47it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 17.02it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 17.14it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 17.10it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 17.89it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 18.18it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 18.47it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 18.16it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 13.74it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 16.89it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.22it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 18.33it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [01:04<00:00, 32.17s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [14:48<00:00, 74.06s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 20 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:31<00:00, 219.78it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [01:58<00:00, 168.58it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [11:19<00:00, 14.72it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:15<00:00, 18.01it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:48, 19.74it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 17.12it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 18.50it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 18.48it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 17.80it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 15.84it/s, avg_epoch_loss=0, epoch=4]
100%|██████████| 7/7 [00:00<00:00, 16.45it/s, avg_epoch_loss=0, epoch=5]
100%|██████████| 7/7 [00:00<00:00, 17.90it/s, avg_epoch_loss=0, epoch=6]
100%|██████████| 7/7 [00:00<00:00, 16.60it/s, avg_epoch_loss=0, epoch=7]
100%|██████████| 7/7 [00:00<00:00, 17.85it/s, avg_epoch_loss=0, epoch=8]
100%|██████████| 7/7 [00:00<00:00, 17.25it/s, avg_epoch_loss=0, epoch=9]
100%|██████████| 7/7 [00:00<00:00, 17.42it/s, avg_epoch_loss=0, epoch=10]
100%|██████████| 7/7 [00:00<00:00, 17.41it/s, avg_epoch_loss=0, epoch=11]
100%|██████████| 7/7 [00:00<00:00, 17.15it/s, avg_epoch_loss=0, epoch=12]
100%|██████████| 7/7 [00:00<00:00, 17.91it/s, av

Start validation!!!
Epoch: 49


100%|██████████| 2/2 [00:59<00:00, 29.61s/it]


##### End evaluation!!
PEHE VAL = 8.1e+03


100%|██████████| 7/7 [00:00<00:00, 16.38it/s, avg_epoch_loss=0, epoch=50]
100%|██████████| 7/7 [00:00<00:00, 15.98it/s, avg_epoch_loss=0, epoch=51]
100%|██████████| 7/7 [00:00<00:00, 17.88it/s, avg_epoch_loss=0, epoch=52]
100%|██████████| 7/7 [00:00<00:00, 18.09it/s, avg_epoch_loss=0, epoch=53]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=54]
100%|██████████| 7/7 [00:00<00:00, 18.02it/s, avg_epoch_loss=0, epoch=55]
100%|██████████| 7/7 [00:00<00:00, 18.44it/s, avg_epoch_loss=0, epoch=56]
100%|██████████| 7/7 [00:00<00:00, 17.78it/s, avg_epoch_loss=0, epoch=57]
100%|██████████| 7/7 [00:00<00:00, 17.81it/s, avg_epoch_loss=0, epoch=58]
100%|██████████| 7/7 [00:00<00:00, 17.94it/s, avg_epoch_loss=0, epoch=59]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s, avg_epoch_loss=0, epoch=60]
100%|██████████| 7/7 [00:00<00:00, 17.71it/s, avg_epoch_loss=0, epoch=61]
100%|██████████| 7/7 [00:00<00:00, 17.47it/s, avg_epoch_loss=0, epoch=62]
100%|██████████| 7/7 [00:00<00:00, 17.

Start validation!!!
Epoch: 99


100%|██████████| 2/2 [01:07<00:00, 33.60s/it]


##### End evaluation!!
PEHE VAL = 8.29e+03


100%|██████████| 7/7 [00:00<00:00, 16.44it/s, avg_epoch_loss=0, epoch=100]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=101]
100%|██████████| 7/7 [00:00<00:00, 17.51it/s, avg_epoch_loss=0, epoch=102]
100%|██████████| 7/7 [00:00<00:00, 18.84it/s, avg_epoch_loss=0, epoch=103]
100%|██████████| 7/7 [00:00<00:00, 18.24it/s, avg_epoch_loss=0, epoch=104]
100%|██████████| 7/7 [00:00<00:00, 20.00it/s, avg_epoch_loss=0, epoch=105]
100%|██████████| 7/7 [00:00<00:00, 14.24it/s, avg_epoch_loss=0, epoch=106]
100%|██████████| 7/7 [00:00<00:00, 12.32it/s, avg_epoch_loss=0, epoch=107]
100%|██████████| 7/7 [00:00<00:00,  8.78it/s, avg_epoch_loss=0, epoch=108]
100%|██████████| 7/7 [00:00<00:00, 15.55it/s, avg_epoch_loss=0, epoch=109]
100%|██████████| 7/7 [00:00<00:00, 12.87it/s, avg_epoch_loss=0, epoch=110]
100%|██████████| 7/7 [00:00<00:00, 13.16it/s, avg_epoch_loss=0, epoch=111]
100%|██████████| 7/7 [00:00<00:00, 18.01it/s, avg_epoch_loss=0, epoch=112]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 149


100%|██████████| 2/2 [01:07<00:00, 33.80s/it]


##### End evaluation!!
PEHE VAL = 8.48e+03


100%|██████████| 7/7 [00:00<00:00, 15.26it/s, avg_epoch_loss=0, epoch=150]
100%|██████████| 7/7 [00:00<00:00, 16.89it/s, avg_epoch_loss=0, epoch=151]
100%|██████████| 7/7 [00:00<00:00, 18.64it/s, avg_epoch_loss=0, epoch=152]
100%|██████████| 7/7 [00:00<00:00, 18.75it/s, avg_epoch_loss=0, epoch=153]
100%|██████████| 7/7 [00:00<00:00, 19.00it/s, avg_epoch_loss=0, epoch=154]
100%|██████████| 7/7 [00:00<00:00, 14.03it/s, avg_epoch_loss=0, epoch=155]
100%|██████████| 7/7 [00:00<00:00, 17.22it/s, avg_epoch_loss=0, epoch=156]
100%|██████████| 7/7 [00:00<00:00, 15.75it/s, avg_epoch_loss=0, epoch=157]
100%|██████████| 7/7 [00:00<00:00, 17.75it/s, avg_epoch_loss=0, epoch=158]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=159]
100%|██████████| 7/7 [00:00<00:00, 15.03it/s, avg_epoch_loss=0, epoch=160]
100%|██████████| 7/7 [00:00<00:00, 16.69it/s, avg_epoch_loss=0, epoch=161]
100%|██████████| 7/7 [00:00<00:00, 18.77it/s, avg_epoch_loss=0, epoch=162]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 199


100%|██████████| 2/2 [01:08<00:00, 34.46s/it]


##### End evaluation!!
PEHE VAL = 8.64e+03


100%|██████████| 7/7 [00:00<00:00, 16.82it/s, avg_epoch_loss=0, epoch=200]
100%|██████████| 7/7 [00:00<00:00, 16.83it/s, avg_epoch_loss=0, epoch=201]
100%|██████████| 7/7 [00:00<00:00, 18.87it/s, avg_epoch_loss=0, epoch=202]
100%|██████████| 7/7 [00:00<00:00, 18.07it/s, avg_epoch_loss=0, epoch=203]
100%|██████████| 7/7 [00:00<00:00,  8.61it/s, avg_epoch_loss=0, epoch=204]
100%|██████████| 7/7 [00:01<00:00,  6.36it/s, avg_epoch_loss=0, epoch=205]
100%|██████████| 7/7 [00:00<00:00,  9.81it/s, avg_epoch_loss=0, epoch=206]
100%|██████████| 7/7 [00:00<00:00, 12.12it/s, avg_epoch_loss=0, epoch=207]
100%|██████████| 7/7 [00:00<00:00, 16.18it/s, avg_epoch_loss=0, epoch=208]
100%|██████████| 7/7 [00:00<00:00, 16.65it/s, avg_epoch_loss=0, epoch=209]
100%|██████████| 7/7 [00:00<00:00, 15.06it/s, avg_epoch_loss=0, epoch=210]
100%|██████████| 7/7 [00:00<00:00, 13.70it/s, avg_epoch_loss=0, epoch=211]
100%|██████████| 7/7 [00:00<00:00, 13.91it/s, avg_epoch_loss=0, epoch=212]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 249


100%|██████████| 2/2 [01:09<00:00, 34.77s/it]


##### End evaluation!!
PEHE VAL = 8.31e+03


100%|██████████| 7/7 [00:00<00:00, 17.56it/s, avg_epoch_loss=0, epoch=250]
100%|██████████| 7/7 [00:00<00:00, 15.70it/s, avg_epoch_loss=0, epoch=251]
100%|██████████| 7/7 [00:00<00:00, 18.05it/s, avg_epoch_loss=0, epoch=252]
100%|██████████| 7/7 [00:00<00:00, 13.28it/s, avg_epoch_loss=0, epoch=253]
100%|██████████| 7/7 [00:00<00:00, 17.61it/s, avg_epoch_loss=0, epoch=254]
100%|██████████| 7/7 [00:00<00:00, 17.66it/s, avg_epoch_loss=0, epoch=255]
100%|██████████| 7/7 [00:00<00:00, 17.48it/s, avg_epoch_loss=0, epoch=256]
100%|██████████| 7/7 [00:00<00:00, 17.67it/s, avg_epoch_loss=0, epoch=257]
100%|██████████| 7/7 [00:00<00:00, 17.36it/s, avg_epoch_loss=0, epoch=258]
100%|██████████| 7/7 [00:00<00:00, 17.91it/s, avg_epoch_loss=0, epoch=259]
100%|██████████| 7/7 [00:00<00:00, 17.36it/s, avg_epoch_loss=0, epoch=260]
100%|██████████| 7/7 [00:00<00:00, 17.49it/s, avg_epoch_loss=0, epoch=261]
100%|██████████| 7/7 [00:00<00:00, 16.59it/s, avg_epoch_loss=0, epoch=262]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 299


100%|██████████| 2/2 [01:09<00:00, 34.57s/it]


##### End evaluation!!
PEHE VAL = 8.37e+03


100%|██████████| 7/7 [00:00<00:00, 16.31it/s, avg_epoch_loss=0, epoch=300]
100%|██████████| 7/7 [00:00<00:00, 17.39it/s, avg_epoch_loss=0, epoch=301]
100%|██████████| 7/7 [00:00<00:00, 16.88it/s, avg_epoch_loss=0, epoch=302]
100%|██████████| 7/7 [00:00<00:00, 17.91it/s, avg_epoch_loss=0, epoch=303]
100%|██████████| 7/7 [00:00<00:00, 18.48it/s, avg_epoch_loss=0, epoch=304]
100%|██████████| 7/7 [00:00<00:00, 15.03it/s, avg_epoch_loss=0, epoch=305]
100%|██████████| 7/7 [00:00<00:00, 17.68it/s, avg_epoch_loss=0, epoch=306]
100%|██████████| 7/7 [00:00<00:00, 18.71it/s, avg_epoch_loss=0, epoch=307]
100%|██████████| 7/7 [00:00<00:00, 15.17it/s, avg_epoch_loss=0, epoch=308]
100%|██████████| 7/7 [00:00<00:00, 17.86it/s, avg_epoch_loss=0, epoch=309]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=310]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=311]
100%|██████████| 7/7 [00:00<00:00, 15.63it/s, avg_epoch_loss=0, epoch=312]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 349


100%|██████████| 2/2 [01:06<00:00, 33.29s/it]


##### End evaluation!!
PEHE VAL = 8.91e+03


100%|██████████| 7/7 [00:00<00:00, 18.54it/s, avg_epoch_loss=0, epoch=350]
100%|██████████| 7/7 [00:00<00:00, 18.99it/s, avg_epoch_loss=0, epoch=351]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=352]
100%|██████████| 7/7 [00:00<00:00, 17.97it/s, avg_epoch_loss=0, epoch=353]
100%|██████████| 7/7 [00:00<00:00, 16.07it/s, avg_epoch_loss=0, epoch=354]
100%|██████████| 7/7 [00:00<00:00, 18.40it/s, avg_epoch_loss=0, epoch=355]
100%|██████████| 7/7 [00:00<00:00, 18.36it/s, avg_epoch_loss=0, epoch=356]
100%|██████████| 7/7 [00:00<00:00, 18.12it/s, avg_epoch_loss=0, epoch=357]
100%|██████████| 7/7 [00:00<00:00, 18.39it/s, avg_epoch_loss=0, epoch=358]
100%|██████████| 7/7 [00:00<00:00, 18.80it/s, avg_epoch_loss=0, epoch=359]
100%|██████████| 7/7 [00:00<00:00, 18.81it/s, avg_epoch_loss=0, epoch=360]
100%|██████████| 7/7 [00:00<00:00, 18.76it/s, avg_epoch_loss=0, epoch=361]
100%|██████████| 7/7 [00:00<00:00, 18.59it/s, avg_epoch_loss=0, epoch=362]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 399


100%|██████████| 2/2 [00:57<00:00, 28.68s/it]


##### End evaluation!!
PEHE VAL = 8.9e+03


100%|██████████| 7/7 [00:00<00:00, 17.92it/s, avg_epoch_loss=0, epoch=400]
100%|██████████| 7/7 [00:00<00:00, 18.28it/s, avg_epoch_loss=0, epoch=401]
100%|██████████| 7/7 [00:00<00:00, 18.51it/s, avg_epoch_loss=0, epoch=402]
100%|██████████| 7/7 [00:00<00:00, 15.12it/s, avg_epoch_loss=0, epoch=403]
100%|██████████| 7/7 [00:00<00:00, 19.52it/s, avg_epoch_loss=0, epoch=404]
100%|██████████| 7/7 [00:00<00:00, 19.76it/s, avg_epoch_loss=0, epoch=405]
100%|██████████| 7/7 [00:00<00:00, 19.89it/s, avg_epoch_loss=0, epoch=406]
100%|██████████| 7/7 [00:00<00:00, 19.75it/s, avg_epoch_loss=0, epoch=407]
100%|██████████| 7/7 [00:00<00:00, 16.47it/s, avg_epoch_loss=0, epoch=408]
100%|██████████| 7/7 [00:00<00:00, 19.64it/s, avg_epoch_loss=0, epoch=409]
100%|██████████| 7/7 [00:00<00:00, 19.59it/s, avg_epoch_loss=0, epoch=410]
100%|██████████| 7/7 [00:00<00:00, 18.13it/s, avg_epoch_loss=0, epoch=411]
100%|██████████| 7/7 [00:00<00:00, 18.39it/s, avg_epoch_loss=0, epoch=412]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 449


100%|██████████| 2/2 [00:58<00:00, 29.48s/it]


##### End evaluation!!
PEHE VAL = 8.03e+03


100%|██████████| 7/7 [00:00<00:00, 18.10it/s, avg_epoch_loss=0, epoch=450]
100%|██████████| 7/7 [00:00<00:00, 18.15it/s, avg_epoch_loss=0, epoch=451]
100%|██████████| 7/7 [00:00<00:00, 18.30it/s, avg_epoch_loss=0, epoch=452]
100%|██████████| 7/7 [00:00<00:00, 18.04it/s, avg_epoch_loss=0, epoch=453]
100%|██████████| 7/7 [00:00<00:00, 18.11it/s, avg_epoch_loss=0, epoch=454]
100%|██████████| 7/7 [00:00<00:00, 18.33it/s, avg_epoch_loss=0, epoch=455]
100%|██████████| 7/7 [00:00<00:00, 17.77it/s, avg_epoch_loss=0, epoch=456]
100%|██████████| 7/7 [00:00<00:00, 17.50it/s, avg_epoch_loss=0, epoch=457]
100%|██████████| 7/7 [00:00<00:00, 17.98it/s, avg_epoch_loss=0, epoch=458]
100%|██████████| 7/7 [00:00<00:00, 18.45it/s, avg_epoch_loss=0, epoch=459]
100%|██████████| 7/7 [00:00<00:00, 18.27it/s, avg_epoch_loss=0, epoch=460]
100%|██████████| 7/7 [00:00<00:00, 18.32it/s, avg_epoch_loss=0, epoch=461]
100%|██████████| 7/7 [00:00<00:00, 18.41it/s, avg_epoch_loss=0, epoch=462]
100%|██████████| 7/7 [00:

Start validation!!!
Epoch: 499


100%|██████████| 2/2 [00:59<00:00, 29.80s/it]


##### End evaluation!!
PEHE VAL = 8.54e+03
Training complete.
Evaluating the model...


100%|██████████| 12/12 [9:47:05<00:00, 2935.49s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type        | Params | Mode 
-------------------------------------------------------
0 | balancer       | Sequential  | 1.6 K  | train
1 | prior_encoder0 | Sequential  | 0      | train
2 | prior_encoder1 | Sequential  | 0      | train
3 | cond_mean0     | Sequential  | 701    | train
4 | cond_mean1     | Sequential  | 701    | train
5 | cond_std0      | Sequential  | 701    | train
6 | cond_std1      | Sequential  | 701    | train
7 | flows0         | DSFMarginal | 892    | train
8 | flows1         | DSFMarginal | 892    | train
-------------------------------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)
45        Modules in train mode
0         Modules in eval mode


Fit and evaluate NOFLITE ...


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=10000` reached.


Fit and evaluate BART ...


Multiprocess sampling (4 chains in 4 jobs)
PGBART: [mu]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 20 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Sampling: [mu, y]


Output()

X_extended (3000, 11)


Sampling: [mu, y]


Output()

Fit and evaluate CMGP ...
Fit and evaluate CEVAE ...
Using cpu


100%|██████████| 7000/7000 [00:32<00:00, 217.48it/s]


Fit and evaluate FCCN ...


100%|██████████| 20000/20000 [02:06<00:00, 157.81it/s]


Fit and evaluate GANITE ...


Training Counterfactual GAN: 100%|██████████| 10000/10000 [08:29<00:00, 19.62it/s]
Training ITE GAN: 100%|██████████| 10000/10000 [09:33<00:00, 17.44it/s]


Fit and evaluate DKLITE ...


  5%|▌         | 51/1000 [00:02<00:55, 17.17it/s]


Fit and evaluate diffpo ...


100%|██████████| 7/7 [00:00<00:00, 13.78it/s, avg_epoch_loss=0, epoch=0]
100%|██████████| 7/7 [00:00<00:00, 15.66it/s, avg_epoch_loss=0, epoch=1]
100%|██████████| 7/7 [00:00<00:00, 13.84it/s, avg_epoch_loss=0, epoch=2]
100%|██████████| 7/7 [00:00<00:00, 14.55it/s, avg_epoch_loss=0, epoch=3]
100%|██████████| 7/7 [00:00<00:00, 14.66it/s, avg_epoch_loss=0, epoch=4]
  0%|          | 0/7 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [4]:
# Display the current value of `n`, the simulation-loop counter, to see which
# iteration the run above had reached when it was interrupted.
n

1

## Evaluating Probability Calibrations

In [6]:
# Configuration for the probability-calibration experiment.
# The two setups are identical except for `gamma` (1 for A, 0 for B).
# NOTE(review): n=3001 presumably leaves 3000 training rows once the loop
# below holds out a single test row — confirm against generate_data.
setup_A = {
    "n": 3001,
    "d": 10,
    "gamma": 1,
    "alpha": 0.1,
    "nexps": 1,
}
setup_B = {
    "n": 3001,
    "d": 10,
    "gamma": 0,
    "alpha": 0.1,
    "nexps": 1,
}

# Base regressor class (instantiated once per nuisance model) and its short
# tag used in output file names.
learner, learner_name = RandomForestRegressor, "RF"
MC_samples = 100

# Flags forwarded to the conformal learners, with the matching file-name tag.
normalized_conformal = True
normalized_conformal_name = "normalized" if normalized_conformal else "nonnormalized"
max_min_y = True

# Number of independent simulations run per setup.
PROB_NSIM = 1000

In [None]:
# Probability-calibration experiment.
#
# For each setup, PROB_NSIM times: draw a fresh synthetic dataset, hold out a
# single test row, fit the conformal learners twice (once passing the `ps`
# column as weights, once without), and record the p-value each learner assigns
# to the true value of the held-out row. All p-values of a setup are written to
# one CSV file.
import os

for i, setup in enumerate([setup_A, setup_B]):
    # Per-setup accumulators; every simulation appends one array per method.
    list_p_values_pseudo_MC_T = []
    list_p_values_MC_T = []
    list_p_values_CT = []
    list_p_values_y0 = []
    list_p_values_y1 = []
    list_p_values_oracle = []
    list_p_values_pseudo_MC_T_unweighted = []
    list_p_values_MC_T_unweighted = []
    list_p_values_CT_unweighted = []
    list_p_values_y0_unweighted = []
    list_p_values_y1_unweighted = []
    setup_name = "A" if i == 0 else "B"

    # Covariate columns X1..Xd. Use the *current* setup's dimension (not
    # setup_A's) so that setups with different `d` are handled correctly, and
    # build the list once instead of twice per simulation.
    feature_cols = ['X' + str(j) for j in range(1, setup["d"] + 1)]

    # Resolve the output location and create the directory before the long
    # simulation starts, so a missing folder cannot discard the results at the
    # very end.
    suffix = "_max_min_y" if max_min_y else ""
    out_dir = f"../../results/outputs/alaa/setup{setup_name}/p_values"
    out_file = f"simulations_{setup_name}_{learner_name}_{normalized_conformal_name}{suffix}_p_values.csv"
    out_path = f"{out_dir}/{out_file}"
    os.makedirs(out_dir, exist_ok=True)

    for n in tqdm(range(PROB_NSIM)):
        ds = generate_data(**setup)
        # An integer test_size is an absolute count: exactly one test row.
        ds_train, ds_test = train_test_split(ds[0], test_size=1)

        W_train = ds_train['T'].to_numpy()
        y_train = ds_train['Y'].to_numpy()
        y1_train = ds_train['Y1'].to_numpy()
        y0_train = ds_train['Y0'].to_numpy()
        X_train = ds_train[feature_cols].to_numpy()
        ps_train = ds_train['ps'].to_numpy()  # presumably propensity scores — confirm
        ite_train = y1_train - y0_train  # true individual treatment effect

        W_test = ds_test['T'].to_numpy()
        y_test = ds_test['Y'].to_numpy()
        y1_test = ds_test['Y1'].to_numpy()
        y0_test = ds_test['Y0'].to_numpy()
        X_test = ds_test[feature_cols].to_numpy()
        ps_test = ds_test['ps'].to_numpy()
        ite_test = y1_test - y0_test

        ## Initialize and fit the learners (weighted)
        conformal_pseudo_MC_T_Learner = CMC_T_Learner(
            learner(),
            learner(),
            normalized_conformal=normalized_conformal,
            pseudo_MC=True,
            MC_samples=MC_samples,
            max_min_y=max_min_y
        )
        conformal_pseudo_MC_T_Learner.fit(X_train, y_train, W_train, ps_train)

        conformal_MC_T_Learner = CMC_T_Learner(
            learner(),
            learner(),
            normalized_conformal=normalized_conformal,
            pseudo_MC=False,
            MC_samples=MC_samples,
            max_min_y=max_min_y
        )
        conformal_MC_T_Learner.fit(X_train, y_train, W_train, ps_train)

        conformal_CT_learner = CCT_Learner(learner(), learner(), normalized_conformal=normalized_conformal)
        conformal_CT_learner.fit(X_train, y_train, W_train, p=ps_train)

        # Oracle baseline: a conformal regressor trained directly on the true
        # ITEs, using a 50/50 nuisance/calibration split seeded by the
        # simulation index.
        (X_train_nuisance, X_train_cal,
            y_train_nuisance, y_train_cal,
            y0_train_nuisance, y0_train_cal,
            y1_train_nuisance, y1_train_cal,
            W_train_nuisance, W_train_cal,
            ite_train_nuisance, ite_train_cal) = train_test_split(
                X_train, y_train, y0_train, y1_train, W_train, ite_train, test_size=0.5, random_state=n
        )
        conformal_ite_oracle = WrapRegressor(learner())
        conformal_ite_oracle.fit(X_train_nuisance, ite_train_nuisance)
        conformal_ite_oracle.calibrate(X_train_cal, ite_train_cal, cps=True)

        # p-values of the true ITE / potential outcomes for the test row.
        # NOTE(review): the CMC learners are queried without `p=ps_test`,
        # unlike the CCT learner — confirm this is intended.
        list_p_values_pseudo_MC_T.append(conformal_pseudo_MC_T_Learner.predict_p_value(X_test, ite_test))
        list_p_values_MC_T.append(conformal_MC_T_Learner.predict_p_value(X_test, ite_test))
        list_p_values_CT.append(conformal_CT_learner.predict_p_value(X_test, ite_test, p=ps_test))
        list_p_values_y0.append(conformal_CT_learner.predict_p_value_y0(X_test, y0_test, p=ps_test))
        list_p_values_y1.append(conformal_CT_learner.predict_p_value_y1(X_test, y1_test, p=ps_test))
        list_p_values_oracle.append(conformal_ite_oracle.predict_cps(X_test, y=ite_test))

        ## Initialize and fit the learners (unweighted)
        # Same learners as above, refit without passing `ps`.
        conformal_pseudo_MC_T_Learner = CMC_T_Learner(
            learner(),
            learner(),
            normalized_conformal=normalized_conformal,
            pseudo_MC=True,
            MC_samples=MC_samples,
            max_min_y=max_min_y
        )
        conformal_pseudo_MC_T_Learner.fit(X_train, y_train, W_train)

        conformal_MC_T_Learner = CMC_T_Learner(
            learner(),
            learner(),
            normalized_conformal=normalized_conformal,
            pseudo_MC=False,
            MC_samples=MC_samples,
            max_min_y=max_min_y
        )
        conformal_MC_T_Learner.fit(X_train, y_train, W_train)

        conformal_CT_learner = CCT_Learner(learner(), learner(), normalized_conformal=normalized_conformal)
        conformal_CT_learner.fit(X_train, y_train, W_train)

        # p-values
        list_p_values_pseudo_MC_T_unweighted.append(conformal_pseudo_MC_T_Learner.predict_p_value(X_test, ite_test))
        list_p_values_MC_T_unweighted.append(conformal_MC_T_Learner.predict_p_value(X_test, ite_test))
        list_p_values_CT_unweighted.append(conformal_CT_learner.predict_p_value(X_test, ite_test))
        list_p_values_y0_unweighted.append(conformal_CT_learner.predict_p_value_y0(X_test, y0_test))
        list_p_values_y1_unweighted.append(conformal_CT_learner.predict_p_value_y1(X_test, y1_test))

    # Flatten the per-simulation arrays into one column per method.
    dict_p_values = {
        "pseudo_MC_T": np.concatenate(list_p_values_pseudo_MC_T),
        "MC_T": np.concatenate(list_p_values_MC_T),
        "CT": np.concatenate(list_p_values_CT),
        "y0": np.concatenate(list_p_values_y0),
        "y1": np.concatenate(list_p_values_y1),
        "oracle": np.concatenate(list_p_values_oracle),
        "pseudo_MC_T_unweighted": np.concatenate(list_p_values_pseudo_MC_T_unweighted),
        "MC_T_unweighted": np.concatenate(list_p_values_MC_T_unweighted),
        "CT_unweighted": np.concatenate(list_p_values_CT_unweighted),
        "y0_unweighted": np.concatenate(list_p_values_y0_unweighted),
        "y1_unweighted": np.concatenate(list_p_values_y1_unweighted),
    }
    df_p_values = pd.DataFrame(dict_p_values)
    # The file name carries a "_max_min_y" tag only when that option is on.
    df_p_values.to_csv(out_path, index=False)