# Conditional tabular DDPM Exploration

1. Simulate tabular data using the example from Syn
2. Build tabular DDPM from scratch based on: https://medium.com/mlearning-ai/enerating-images-with-ddpms-a-pytorch-implementation-cef5a2ba8cb1
3. Train and evaluate the generation result
4. Modify reverse process to achieve conditional generation via missing value imputation
5. If it didn't work, try conditioning 



In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from sklearn.preprocessing import QuantileTransformer
from scipy.stats import norm


import pickle
from tqdm import tqdm
import joblib

import matplotlib.pyplot as plt
import seaborn as sns
# import ptitprince as pt

from utils_data import TrueSampler
from utils_model import MLPDiffusionContinuous

from ddpm import MyDDPM, generate_samples, generate_imputation

import sys

sys.path.insert(0, "../tab-ddpm/pass-inference/syninf/")
from utils_viz import plot_distribution, compare_distributions_grid, heatmap_correlation
from utils_syninf import catboost_pred_model, test_rmse

seed = 2024

# Inference

Run `python train_script.py` to get model checkpoints and saved pipeline for quantile transformation

In [4]:
# Initialization #

sigma = 0.2

d_in = 8
d_time = 128
hidden_dims = [512, 256, 256, 256, 256, 128]
n_steps = 1000

device = "cuda:7"


noise_pred_network = MLPDiffusionContinuous(
    d_in=d_in, hidden_dims=hidden_dims, dim_t=d_time
)
tabular_ddpm = MyDDPM(network=noise_pred_network, n_steps=n_steps, device=device)


# Loading trained model for inference
tabular_ddpm.load_state_dict(torch.load("./ckpt/tabular_ddpm.pt"))

# Loading quantile transformer pipeline used during training
qt = joblib.load("./ckpt/qt_train.joblib")

# True sample
yx_true_unnorm = np.load("./data/yx.npy")
yx_true_norm = qt.transform(yx_true_unnorm)

## Conditional generation via missing value imputation technique

* RePaint: https://arxiv.org/pdf/2201.09865.pdf
* TabSyn: https://arxiv.org/pdf/2310.09656v1.pdf

In [5]:
# Masking the response column
yx_mask = np.ones(yx_true_unnorm.shape)
yx_mask[:, 0] = 0

# Conditional generation
yx_true_norm = torch.tensor(yx_true_norm, dtype=torch.float32)
yx_mask = torch.tensor(yx_mask, dtype=torch.float32)
yx_fake_pred = generate_imputation(
    tabular_ddpm, yx_true_norm, yx_mask, resampling_steps=10
)

                                                                          

### Supervised evaluation on independent test set

I found the generation speed is the same for n_samples = 200 versus n_samples = 5000

In [7]:
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Generate test data #
n_test = 500
true_sampler = TrueSampler(sigma=sigma)
X, y, mean_y = true_sampler.sample(n_test, return_mean=True) # return E(y|X) as well
yx_test = np.concatenate([y[:, None], X], axis=1)

np.save("./data/yx_test.npy", yx_test)
np.save("./data/mean_y_test.npy", mean_y)



# Quantile transformation #
qt = QuantileTransformer(output_distribution="normal", random_state=seed)
qt.fit(yx_test)
yx_norm = qt.transform(yx_test)
yx_norm = torch.tensor(yx_norm, dtype=torch.float32)

joblib.dump(qt, "./ckpt/qt_test.joblib")


# Get the test data and the corresponding quantile transformer #
yx_unnorm = np.load("./data/yx_test.npy")
qt = joblib.load("./ckpt/qt_test.joblib")
yx_norm = qt.transform(yx_unnorm)


X shape:  (500, 7)
y shape:  (500,)




In [8]:
# Conditional generation for predicting y #
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


n_test = yx_norm.shape[0]
D = 200  # Monte Carlo size
D_batch = 100  # non-linear growth, so split into batches

batch_sizes = [D_batch] * (D // D_batch) + ([D % D_batch] if D % D_batch != 0 else [])

y_dict = {
    "y_true": yx_unnorm[:, 0],
    "y_pred_list": [],
}

for i, batch_size in enumerate(batch_sizes):
    print(f"Processing batch {i} with batch size {batch_size}...")

    # stacking all together for fast processing
    input_norm = np.vstack([yx_norm] * batch_size)

    input_norm = torch.tensor(input_norm, dtype=torch.float32)
    input_mask = torch.ones(input_norm.shape, dtype=torch.float32)
    input_mask[:, 0] = 0

    output_norm = generate_imputation(
        tabular_ddpm, input_norm, input_mask, resampling_steps=10
    )
    output_unnorm = qt.inverse_transform(output_norm.cpu().detach().numpy())

    # save the result
    y_dict["y_pred_list"].extend(
        [temp[:, 0] for temp in np.split(output_unnorm, batch_size, axis=0)]
    )
    pickle.dump(y_dict, open(f"./result/y_pred_dict.pkl", "wb"))

Processing batch 0 with batch size 100...


                                                                          

Processing batch 1 with batch size 100...


                                                                          

In [9]:
y_dict = pickle.load(open("./result/y_pred_dict.pkl", "rb"))

y_true = y_dict["y_true"]
y_pred_list = y_dict["y_pred_list"]

y_true.shape

(500,)

In [22]:
# Evaluation #

y_dict = pickle.load(open("./result/y_pred_dict.pkl", "rb"))

mean_y = np.load("./data/mean_y_test.npy")
sigma_x = np.load("./data/sigma_x_test.npy")
y_true = y_dict["y_true"]
y_pred_list = y_dict["y_pred_list"]
y_pred_array = np.stack(y_pred_list, axis=0)  # (D, n_test)

# RMSE
y_pred_synthetic = y_pred_array.mean(axis=0)
rmse = np.sqrt(np.mean((y_true - y_pred_synthetic) ** 2))
print("RMSE:", rmse)


RMSE: 0.21656490477695806


In [24]:
# Standard error 

np.sqrt(np.mean((y_pred_array - y_true[None, :])**2, axis=1))

array([0.31416175, 0.32583073, 0.31993672, 0.29716668, 0.31620187,
       0.32154311, 0.30349527, 0.31919774, 0.31875214, 0.31224021,
       0.31197598, 0.307265  , 0.32297158, 0.31934985, 0.31551817,
       0.31899327, 0.30701574, 0.31415524, 0.30916237, 0.31745115,
       0.29576875, 0.30914911, 0.31019274, 0.31510035, 0.3237752 ,
       0.31517443, 0.30565522, 0.32445722, 0.30586117, 0.31486455,
       0.30223765, 0.31257032, 0.32340411, 0.31598996, 0.31551358,
       0.30673915, 0.32241415, 0.312895  , 0.31701488, 0.30013982,
       0.32313256, 0.32650639, 0.30836197, 0.30942243, 0.31509253,
       0.30022467, 0.32694155, 0.31156378, 0.31890919, 0.31872372,
       0.30989444, 0.31324203, 0.30661241, 0.32430769, 0.32574809,
       0.32608053, 0.30702242, 0.3141259 , 0.31038802, 0.30719156,
       0.31813655, 0.3131604 , 0.30893857, 0.32688024, 0.31050423,
       0.3144743 , 0.32758273, 0.30990415, 0.32262353, 0.32271269,
       0.3093788 , 0.3187489 , 0.31491559, 0.31364911, 0.30494

### Conformal prediction interval

In [11]:
yx_train_unnorm = np.load("./data/yx.npy")
yx_test_unnorm = np.load("./data/yx_test.npy")

# Simulate transfer learning and fine-tuning: Use a subset of samples for training the traditional predictive model
n_traditional = 1000
yx_train_unnorm = yx_train_unnorm[:n_traditional]
yx_test_unnorm = yx_test_unnorm[:n_traditional]

columns_names = ["y", "x1", "x2", "x3", "x4", "x5", "x6", "x7"]
df_train = pd.DataFrame(yx_train_unnorm)
df_test = pd.DataFrame(yx_test_unnorm)
df_train.columns = df_test.columns = columns_names

n, n_test = len(yx_train_unnorm), len(yx_test_unnorm)
r_model, r_val = 0.8, 0.2
n_model, n_val = int(n * r_model), int(n * r_val)

df_model = df_train.iloc[:n_model, :]
df_val = df_train.iloc[n_model:, :]


# train predictive model with early stopping
model_fit = catboost_pred_model(df_model, df_val, num_features_list=columns_names[1:])

no null features, using all specified features for training
Learning rate set to 0.049168
0:	learn: 0.5503667	test: 0.5070400	best: 0.5070400 (0)	total: 50.4ms	remaining: 50.3s
1:	learn: 0.5378831	test: 0.4957794	best: 0.4957794 (1)	total: 53.2ms	remaining: 26.5s
2:	learn: 0.5245820	test: 0.4853841	best: 0.4853841 (2)	total: 55.6ms	remaining: 18.5s
3:	learn: 0.5124515	test: 0.4773442	best: 0.4773442 (3)	total: 57.8ms	remaining: 14.4s
4:	learn: 0.5013724	test: 0.4686287	best: 0.4686287 (4)	total: 59.5ms	remaining: 11.8s
5:	learn: 0.4905277	test: 0.4609893	best: 0.4609893 (5)	total: 61.2ms	remaining: 10.1s
6:	learn: 0.4807613	test: 0.4533791	best: 0.4533791 (6)	total: 62.6ms	remaining: 8.88s
7:	learn: 0.4712948	test: 0.4467117	best: 0.4467117 (7)	total: 64ms	remaining: 7.93s
8:	learn: 0.4608339	test: 0.4392130	best: 0.4392130 (8)	total: 65.4ms	remaining: 7.2s
9:	learn: 0.4517376	test: 0.4328325	best: 0.4328325 (9)	total: 66.8ms	remaining: 6.62s
10:	learn: 0.4437129	test: 0.4264851	best: 

In [14]:
# Evaluate the result on the test set
rmse = test_rmse(model_fit, df_test, columns_names[1:])
print("RMSE on the test set for traditional approach:", rmse)

RMSE on the test set for traditional approach: 0.22782445705253043
