# ClavaDDPM: Multi-relational Data Synthesis with Cluster-guided Diffusion Models

Recent tabular data synthesis research has focused on single tables, while real-world applications often involve complex, interconnected tables. Existing methods for multi-relational data synthesis struggle with scalability and long-range dependencies. This paper introduces Cluster Latent Variable guided Denoising Diffusion Probabilistic Models (ClavaDDPM), using clustering labels to model inter-table relationships, particularly foreign key constraints. ClavaDDPM efficiently propagates latent variables across tables, capturing long-range dependencies. Evaluations show ClavaDDPM outperforms existing methods on multi-table data and remains competitive for single-table data.

In the following sections, we will delve deeper into the implementation of this method. The notebook is organized as follows:

1. [Imports and Setup]()


2. [Load Configuration]()


3. [Data Loading and Preprocessing]()
    
    
4. [ClavaDDPM Algorithm]()

    4.1. [Overview]()
    
    4.2. [Clustering]()
    
    4.3. [Model Training]()
    
    4.4. [Model Sampling]()
    
    
    
5. [Model Evaluation]()

    5.1. [Multi-Table Metrics]()
    
    5.2. [Single-Table Metrics]()

# Imports and Setup

In this section, we import the libraries and modules needed for the rest of the notebook: the standard `json` module for displaying the configuration, and the custom ClavaDDPM modules that handle configuration loading, data loading, clustering, training, synthesis, and evaluation. Heavier dependencies such as PyTorch and numpy are pulled in by these custom modules.

In [1]:
import json

from midst_models.multi_table_ClavaDDPM.complex_pipeline import (
    clava_clustering,
    clava_training,
    clava_load_pretrained,
    clava_load_synthesized_data,
    clava_load_synthesized_data_updated,
    clava_synthesizing,
    clava_eval,
    load_configs,
)
from midst_models.multi_table_ClavaDDPM.pipeline_modules import load_multi_table
from midst_models.multi_table_ClavaDDPM.report_utils import get_multi_metadata


# Load Configuration

In this section, we load the configuration file, which holds the parameters and settings for the training process. The file is stored in `json` format and is parsed into a dictionary. The code cell below prints the entire configuration; the individual hyperparameters are explained in the sections where they are used.

A sample configuration file is available at `configs/berka.json`, where general parameters can be modified as needed.

In [None]:
# Load config
config_path = "configs/berka.json"
configs, save_dir = load_configs(config_path)

# Display config
json_str = json.dumps(configs, indent=4)
print(json_str)
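
For orientation, the sketch below shows how a configuration of this shape might be read and queried. The section names `general`, `clustering`, `diffusion`, `sampling`, and the optional `debug` are the ones accessed later in this notebook, and the values are the defaults quoted in the text. The `data_dir` path and the helper `get_param` are hypothetical, and the real `configs/berka.json` may contain more keys than shown here.

```python
import json

# Hypothetical config text: section names and default values follow this
# notebook's text; the data_dir path is a placeholder, not the real one.
config_text = """
{
    "general": {"data_dir": "path/to/berka"},
    "clustering": {"parent_scale": 1.0, "num_clusters": 20, "clustering_method": "both"},
    "diffusion": {"iterations": 10000, "batch_size": 4096},
    "sampling": {"classifier_scale": 1.0}
}
"""

example_configs = json.loads(config_text)


def get_param(cfg, section, key, default=None):
    """Return cfg[section][key], or default if the section or key is missing."""
    return cfg.get(section, {}).get(key, default)


num_clusters = get_param(example_configs, "clustering", "num_clusters")
# "debug" is optional, so a missing section falls back to a scale of 1
sample_scale = get_param(example_configs, "debug", "sample_scale", default=1)
print(num_clusters, sample_scale)
```

Reading optional sections through a helper with a default mirrors the `"debug" not in configs` check used later when synthesizing data.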

# Data Loading and Preprocessing

In this section, we load and preprocess the dataset based on the configuration settings. We demonstrate the dataset's metadata and parent-child relationships to provide a clearer understanding of its structure. Following this, we perform clustering to preprocess the data, facilitating the training process for the ClavaDDPM model.

In this notebook, we use all the tables from the Berka dataset. You can access the Berka dataset files for ClavaDDPM [here](https://drive.google.com/drive/folders/1rmJ_E6IzG25eCL3foYAb2jVmAstXktJ1?usp=drive_link). The Berka dataset is a banking dataset originally released by a Czech bank for the PKDD'99 Discovery Challenge in 1999. It provides detailed financial data on transactions, accounts, loans, credit cards, and demographic information for thousands of customers over multiple years.
The following files must be present in the data directory:
- `{table}.csv`: The training subset of each table in the Berka dataset.
- `test.csv`: The transactions subset of the Berka dataset used for evaluation.
- `{table}_label_encoders.pkl`: The label encoders used to encode the table, needed only if you are using the already preprocessed data from the shared files.
- `{table}_domain.json`: The domain information for each column of the table. Sample domain files are available in `configs/domain_files`.
- `dataset_meta.json`: The configuration file that defines the relationships between the tables in the dataset. A sample configuration file is available at `configs/dataset_meta.json`.
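
Before loading the data, it can help to verify that the directory is complete. The following is a minimal sketch of such a check, built only from the file list above. The helper `missing_files` is hypothetical and not part of the ClavaDDPM code base, and the label encoder files are skipped because they are only needed for the preprocessed data.

```python
from pathlib import Path


def missing_files(data_dir, table_names):
    """List the expected files that are absent from data_dir."""
    expected = ["dataset_meta.json", "test.csv"]
    for name in table_names:
        expected += [f"{name}.csv", f"{name}_domain.json"]
    root = Path(data_dir)
    return [fname for fname in expected if not (root / fname).is_file()]


# Placeholder directory and table names, for illustration only
print(missing_files("path/to/berka", ["account", "trans"]))
```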

In [None]:
# Load multi-table dataset
# In this step, we load the multi-table dataset according to the 'dataset_meta.json' file located in the data_dir.
# We organize the multi-table dataset as a dictionary of tables, a list of relation orders, and a dictionary of dataset metadata.
tables, relation_order, dataset_meta = load_multi_table(configs["general"]["data_dir"])
print("")

# Tables is a dictionary of the multi-table dataset
print(
    "{} We show the keys of the tables dictionary below {}".format("=" * 20, "=" * 20)
)
print(list(tables.keys()))
print("")

# Relation order is the topological order of the multi-table dataset
print("{} We show the relation order below {}".format("=" * 20, "=" * 20))
for i, (from_item, to_item) in enumerate(relation_order):
    print("Relation {}: {} ---> {}".format(i, from_item, to_item))
print("")

# Visualize the parent-child relationship within the multi-table dataset
multi_meta = get_multi_metadata(tables, relation_order)
multi_meta.visualize()
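
To make the idea of a topological relation order concrete, here is a small sketch that derives a valid table order from parent-child pairs with the standard library. It assumes that relations are `(parent, child)` tuples and that a root table is encoded with `None` as its parent; the table names are illustrative, and `order_tables` is a hypothetical helper rather than the routine used by `load_multi_table`.

```python
from graphlib import TopologicalSorter


def order_tables(relations):
    """Return tables so that every parent precedes its children.

    relations: iterable of (parent, child) pairs; parent is None for a root
    table (an assumption about how root tables are encoded).
    """
    sorter = TopologicalSorter()
    for parent, child in relations:
        if parent is None:
            sorter.add(child)
        else:
            sorter.add(child, parent)  # the child depends on its parent
    return list(sorter.static_order())


example_relations = [(None, "account"), ("account", "loan"), ("account", "trans")]
print(order_tables(example_relations))
```

Any order in which parents come before their children is valid, which is what the synthesis stage needs in order to generate child tables conditioned on already generated parents.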

# ClavaDDPM Algorithm

## Overview


<img src="https://github.com/user-attachments/assets/b190edb0-a16a-47b7-9534-cab46e40d0d0" alt="ClavaDDPM Model Pipeline" width="960"/>

This section outlines the training process for the ClavaDDPM model. The diagram above, taken from the original paper, illustrates the main steps: 

(a) **Latent learning and table augmentation (steps 1-2):** This step corresponds to the clustering section, where we augment each table with the associated clustering labels that are used to capture inter-table relationships.

(b) **Training (steps 3-5):** This step corresponds to the model training section, where we train separate conditional diffusion models and the cluster classifier models on each augmented table.

(c) **Synthesis (steps 6-8):** This step corresponds to the model sampling section, where we sample the table size and generate data based on the parent-child constraints (i.e., relation order).

We will implement and demonstrate each section below, step by step.

## Clustering
ClavaDDPM first introduces relation-aware clustering to model parent-child constraints and then leverages diffusion models for controlled tabular data synthesis. Specifically, a Gaussian Mixture Model (GMM) is fitted to the joined child and parent rows, and the resulting cluster labels serve as the latent variables that link the tables. We run the clustering algorithm below to preprocess the data for training the ClavaDDPM model. We also empirically estimate the distribution of group sizes (the number of child rows per parent row) in the dataset, which is used later in the sampling process.

Mathematically, let $x$ and $y$ denote the child and parent tables, respectively. We consider $k$ clusters and model the distribution of $h = (x; \lambda y)$ with a Gaussian distribution around its corresponding centroid $c$, i.e.,
$$
P(h) = \sum_{c=1}^{k} P(c) P(h \mid c) = \sum_{c=1}^{k} \pi_c \mathcal{N}(h; \mu_c, \Sigma_c),
$$
where the coefficient $\lambda$ is the reweighting term called the parent scale. We opt for a diagonal covariance, i.e., $\Sigma_c = \operatorname{diag}(\ldots, \sigma_l^2, \ldots)$, which, when properly optimized, immediately satisfies our assumption that the foreign key groups are conditionally independent of their parent rows given the cluster.
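
As a worked illustration of the mixture above, the sketch below evaluates $P(h)$ for a toy model with $k = 2$ clusters and assigns a cluster label by the largest posterior $P(c \mid h)$. All numbers are made up, the helper names are hypothetical, and the actual pipeline fits the mixture parameters from data rather than fixing them by hand.

```python
import math


def diag_gaussian_pdf(h, mean, var):
    """Density of N(h; mean, diag(var)), a product of 1-D Gaussians."""
    density = 1.0
    for h_l, m_l, v_l in zip(h, mean, var):
        density *= math.exp(-((h_l - m_l) ** 2) / (2 * v_l)) / math.sqrt(2 * math.pi * v_l)
    return density


def cluster_posteriors(x, y, parent_scale, weights, means, variances):
    """Return P(c | h) for h = (x; parent_scale * y) under the mixture."""
    h = list(x) + [parent_scale * v for v in y]
    joint = [w * diag_gaussian_pdf(h, m, s) for w, m, s in zip(weights, means, variances)]
    total = sum(joint)  # this sum is P(h)
    return [j / total for j in joint]


# Toy mixture over a 1-D child row x and a 1-D parent row y
post = cluster_posteriors(
    x=[0.1], y=[0.2], parent_scale=1.0,
    weights=[0.5, 0.5],
    means=[[0.0, 0.0], [3.0, 3.0]],
    variances=[[1.0, 1.0], [1.0, 1.0]],
)
label = max(range(len(post)), key=post.__getitem__)
print(post, label)
```

Increasing `parent_scale` stretches the parent coordinates of $h$, so distances along the parent features weigh more in the cluster assignment.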

A summary of the important parameters for the clustering step includes:
- `parent_scale`: the reweighting coefficient $\lambda$ for the parent table. The default value is 1.0. Results are not very sensitive to it. Recommended range for tuning: 1.0 to 2.0.
- `num_clusters`: the number of clustering centers $k$ for child tables. The default is 20. Too few or too many clusters may compromise performance. Recommended range for tuning: 10 to 50.
- `clustering_method`: ClavaDDPM provides two ways to initialize GMM-based clustering. `gmm` uses `kmeans` for initialization, and the default method `both` uses `k-means++`. It is recommended to use `both`.

We expect the `clava_clustering` function to return the following outputs:
- `tables`: the updated relational tables after augmentation, where the latent variable is attached to the parent tables.
- `all_group_lengths_prob_dicts`: a dictionary of group size distributions for each table, used in the sampling stage to determine the size of the tables to generate.
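
To illustrate what a group size distribution is, the sketch below counts how many child rows reference each parent row and normalizes the counts. It is a simplified stand-in: the helper `group_size_distribution` is hypothetical, and the exact layout of `all_group_lengths_prob_dicts` (for example, how it is keyed by relation) is not shown in this notebook and may differ.

```python
from collections import Counter


def group_size_distribution(foreign_keys, parent_ids):
    """Empirical distribution of the number of child rows per parent row."""
    children_per_parent = Counter(foreign_keys)
    sizes = Counter(children_per_parent.get(pid, 0) for pid in parent_ids)
    n = len(parent_ids)
    return {size: count / n for size, count in sorted(sizes.items())}


# Parent 1 has two children, parent 2 one, parent 3 none, parent 4 three
dist = group_size_distribution(foreign_keys=[1, 1, 2, 4, 4, 4], parent_ids=[1, 2, 3, 4])
print(dist)
```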

In [None]:
# Display important clustering parameters
params_clustering = configs["clustering"]
print("{} We show the clustering parameters below {}".format("=" * 20, "=" * 20))
for key, val in params_clustering.items():
    print(f"{key}: {val}")
print("")

# Clustering on the multi-table dataset
tables, all_group_lengths_prob_dicts = clava_clustering(
    tables, relation_order, save_dir, configs
)

## Model Training 

With the clustering results set, we can proceed to launch the training in PyTorch. As the dataset contains multiple tables, we train separate conditional diffusion models and cluster classifier models on each augmented table. Important parameters for the training process include:

- `d_layers`: the dimension of layers in the diffusion model. 
- `num_timesteps`: the number of diffusion steps for adding noise and denoising. 
- `iterations`: the number of training iterations. The default is 10000. Recommended range for tuning: 5000 to 20000.
- `batch_size`: the batch size for training. The default is 4096. 
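
To give some intuition for `num_timesteps`, the sketch below implements the standard DDPM forward (noising) process for continuous values, $x_t \sim \mathcal{N}(\sqrt{\bar{\alpha}_t}\, x_0, (1 - \bar{\alpha}_t) I)$. The linear beta schedule and its endpoints are a common choice that we assume here for illustration; the schedule actually used by ClavaDDPM is set in its own code and is not shown in this notebook.

```python
import math
import random


def alpha_bars(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    out, running = [], 1.0
    for t in range(num_timesteps):
        beta = beta_start + (beta_end - beta_start) * t / max(num_timesteps - 1, 1)
        running *= 1.0 - beta
        out.append(running)
    return out


def q_sample(x0, t, abar, rng):
    """Draw x_t ~ N(sqrt(abar_t) * x0, (1 - abar_t) * I)."""
    scale, sigma = math.sqrt(abar[t]), math.sqrt(1.0 - abar[t])
    return [scale * v + sigma * rng.gauss(0.0, 1.0) for v in x0]


abar = alpha_bars(num_timesteps=1000)
rng = random.Random(0)
x_early = q_sample([1.0, -1.0], t=0, abar=abar, rng=rng)    # almost the clean row
x_late = q_sample([1.0, -1.0], t=999, abar=abar, rng=rng)   # almost pure noise
print(abar[0], abar[-1])
```

With more timesteps each denoising step removes less noise, which usually makes the learning problem easier per step at the cost of slower sampling.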

In [None]:
# Display important diffusion (training) parameters
params_diffusion = configs["diffusion"]
print(
    "{} We show the important diffusion parameters below {}".format("=" * 20, "=" * 20)
)
for key, val in params_diffusion.items():
    print(f"{key}: {val}")
print("")

### PyTorch Training from Scratch
Training is implemented as a custom PyTorch routine that trains one conditional generative model per relation, logs progress, and saves the trained models to the specified directory before returning them for further use. This happens in the `clava_training` function, which takes the following inputs:

- `tables`: the relational tables with data augmentation.
- `configs`: the configuration dictionary with hyperparameters and settings for the training process.
- `relation_order`: the parent-child relationships between tables.
- `save_dir`: the directory to save the trained models and logs.


In [None]:
# Relation order is the topological order of the multi-table dataset
print(
    "{} We show the relation order again, each line indicates one conditional generative model {}".format(
        "=" * 20, "=" * 20
    )
)
for i, (from_item, to_item) in enumerate(relation_order):
    print("Relation {}: {} ---> {}".format(i, from_item, to_item))
print("")

# Launch training from scratch
models = clava_training(tables, relation_order, save_dir, configs)

### Loading Pretrained Models
If training from scratch takes too long, run the following cell to load pre-trained models instead.

In [None]:
# Use the pre-trained models
## save_dir was determined when loading the config file
models = clava_load_pretrained(relation_order, save_dir)

## Model Sampling

To assess the trained model's performance, we generate synthetic samples and showcase the results qualitatively. We initiate the generation process by sampling the table size (i.e., the number of rows per table) and performing conditional generation to meet the parent-child constraints (i.e., relation order). Quantitative evaluations will be conducted in the next section.

Important parameters for the sampling process include:
- `batch_size`: Mini-batch size for sampling.
- `classifier_scale` ($\eta$): Controls the magnitude of classifier gradients during guided sampling, balancing sample quality and conditional sampling accuracy. The default value is 1.0. When $\eta = 0$ (disabling classifier conditioning), single column densities (1-way) may improve but fail to capture long-range correlations. When $\eta = 2$, the increased conditioning weight significantly improves the modeling of multi-hop correlations compared to $\eta = 0$. Recommended tuning range: 0 to 2.
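
The effect of `classifier_scale` can be sketched with the usual classifier-guidance update, in which the mean of a denoising step is shifted along the gradient of the classifier's log-probability: $\hat{\mu} = \mu + \eta\, \Sigma\, \nabla_x \log p(c \mid x)$. The toy logistic classifier and the helper names below are assumptions for illustration; the cluster classifiers trained by ClavaDDPM are neural networks and their gradients come from automatic differentiation.

```python
import math


def guided_mean(mean, variance, grad_log_prob, classifier_scale):
    """Shift a denoising-step mean along the classifier gradient.

    Computes mean + classifier_scale * variance * grad element-wise, i.e. the
    guidance update for a Gaussian step with diagonal covariance.
    """
    return [m + classifier_scale * v * g for m, v, g in zip(mean, variance, grad_log_prob)]


def grad_log_sigmoid(x, w):
    """Gradient of log sigmoid(w . x) with respect to x (toy classifier)."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    return [wi * (1.0 - 1.0 / (1.0 + math.exp(-z))) for wi in w]


mean, variance = [0.0, 0.0], [0.1, 0.1]
grad = grad_log_sigmoid(mean, w=[2.0, -1.0])
unguided = guided_mean(mean, variance, grad, classifier_scale=0.0)
guided = guided_mean(mean, variance, grad, classifier_scale=2.0)
print(unguided, guided)
```

With $\eta = 0$ the mean is unchanged, so the parent cluster has no influence on the sample; larger $\eta$ pulls samples toward regions the classifier associates with the target cluster.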

<!-- We also conduct a matching process to determine the parent-child table relationship of the generated data. This process uses an approximate nearest neighbor search-based matching technique, providing a universal solution to the multi-parent relational synthesis problem for a child table with multiple parents.

Important parameters are as follows:
- `num_matching_clusters`: Number of clusters used in the matching process.
- `matching_batch_size`: Mini-batch size for table matching.
- `unique_matching`: Boolean flag indicating if the matching result should be unique.
- `no_matching`: Boolean flag to disable the matching process. -->

In [None]:
# Display important sampling parameters
params_sampling = configs["sampling"]
print(
    "{} We show the important sampling parameters below {}".format("=" * 20, "=" * 20)
)
for key, val in params_sampling.items():
    print(f"{key}: {val}")
print("")

### Generating Data from Scratch
To generate synthetic data from scratch, we run the following code cell. This `clava_synthesizing` function gets the following inputs:

- `tables`: the relational tables with data augmentation.
- `relation_order`: the parent-child relationships between tables.
- `save_dir`: the directory to save the synthetic data.
- `all_group_lengths_prob_dicts`: a dictionary that computes group size distributions for each table, used in the sampling stage to determine the size of the tables to generate.
- `models`: the trained diffusion models.
- `configs`: the configuration dictionary with hyperparameters and settings for the sampling process.
- `sample_scale`: the scale factor for the sampling process.

The synthetic data will be saved in the specified output directory.
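
The first part of synthesis, sampling how many child rows each parent row receives, can be sketched as follows. The helper `sample_child_counts` is hypothetical, and the way `sample_scale` is applied here (scaling the number of parent rows) is our assumption; `clava_synthesizing` may apply it differently.

```python
import random


def sample_child_counts(size_dist, num_parents, sample_scale=1.0, seed=0):
    """Draw a child-row count for each synthetic parent row."""
    rng = random.Random(seed)
    sizes, probs = zip(*sorted(size_dist.items()))
    n = max(1, round(num_parents * sample_scale))
    return rng.choices(sizes, weights=probs, k=n)


# Group sizes 0, 1, and 3 with made-up probabilities
counts = sample_child_counts({0: 0.25, 1: 0.5, 3: 0.25}, num_parents=8, sample_scale=0.5)
print(counts, "total child rows:", sum(counts))
```

The total of the sampled counts fixes the size of the child table, and each child row is then generated conditioned on its parent.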

In [None]:
# Generate synthetic data from scratch
cleaned_tables, synthesizing_time_spent, matching_time_spent = clava_synthesizing(
    tables,
    relation_order,
    save_dir,
    all_group_lengths_prob_dicts,
    models,
    configs,
    sample_scale=1 if "debug" not in configs else configs["debug"]["sample_scale"],
)

Finally, as some integer values are saved as strings during this process, we convert them back to integers for further evaluation.

In [None]:
# Cast int values that were saved as strings back to int for further evaluation
for key in cleaned_tables.keys():
    for col in cleaned_tables[key].columns:
        if cleaned_tables[key][col].dtype == "object":
            try:
                cleaned_tables[key][col] = cleaned_tables[key][col].astype(int)
            except ValueError:
                print(f"Column {col} cannot be converted to int.")

# Model Evaluation

### Removing zero decimal values

Since the categorical features might carry a trailing zero decimal (e.g., `3.0` instead of `3`), we need to convert them back to integers for evaluation.
The list of categorical columns is provided in the `configs/info_files/{table}_info.json` files. We convert the floating-point values back to integers and save the synthetic data in the same format as the original data.

In [None]:
for table_name, table in cleaned_tables.items():
    info_path = f"configs/info_files/{table_name}_info.json"

    with open(info_path, "r") as f:
        info = json.load(f)

    cat_col_idx = info["cat_col_idx"]
    target_col_idx = info["target_col_idx"]

    task_type = info["task_type"]
    if task_type != "regression":
        cat_col_idx += target_col_idx

    for idx in cat_col_idx:
        col = table.columns[idx]
        table[col] = table[col].apply(
            lambda x: int(float(x)) if str(x).replace(".0", "").isdigit() else x
        )
    table.to_csv(f"{save_dir}/{table_name}_synthetic_updated.csv", index=False)
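
Note that the string-based check in the cell above treats a value such as `"10.05"` as `"105"` after the replacement and would truncate it to `10`, and it leaves negative values such as `"-2.0"` untouched. A stricter alternative is sketched below; `strip_zero_decimal` is a hypothetical helper, not part of the ClavaDDPM code, and it converts only values that are exactly whole numbers.

```python
import math


def strip_zero_decimal(value):
    """Return int(value) if value is a whole number such as 3.0 or "3.0"."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return value  # non-numeric labels pass through unchanged
    if not math.isfinite(number):
        return value  # NaN and infinities are left as-is
    return int(number) if number == int(number) else value


converted = [strip_zero_decimal(v) for v in ["3.0", 4.0, "-2.0", "10.05", "A", None]]
print(converted)
```

It could be used in place of the lambda, for example `table[col] = table[col].apply(strip_zero_decimal)`.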

### Multi-Table Metrics
In this step, we quantitatively evaluate the generated tabular data by computing metrics that assess how closely the synthetic data matches the real samples in the reference dataset.

In particular, the critical multi-table metrics are as follows:

1. Pair-wise column correlation (k-hop): This metric measures the correlations between columns from tables at a distance k (e.g., 0-hop for columns within the same table, 1-hop for a column and a column from its parent or child table).

2. Average 2-way: This metric computes the average of all k-hop column-pair correlations, taking into account both short-range (k = 0) and longer-range (k > 0) dependencies.

We use the [SDV evaluation API](https://docs.sdv.dev/sdv/multi-table-data/evaluation/data-quality) to obtain these metrics. For more details about the computation process, refer to their documentation.
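
One way to read the hop distance $k$ is as the number of parent-child links separating two tables. The sketch below computes it with a breadth-first search over the relation graph. This is our interpretation for illustration: the table names are examples, `hop_distance` is a hypothetical helper, and the evaluation code may define hops along specific parent-child chains instead.

```python
from collections import deque


def hop_distance(relations, start, goal):
    """Number of parent-child links separating two tables, or None."""
    neighbours = {}
    for parent, child in relations:
        if parent is None:
            continue  # root marker, no edge to add
        neighbours.setdefault(parent, set()).add(child)
        neighbours.setdefault(child, set()).add(parent)
    queue, seen = deque([(start, 0)]), {start}
    while queue:
        table, dist = queue.popleft()
        if table == goal:
            return dist
        for nxt in neighbours.get(table, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return None


rels = [(None, "account"), ("account", "loan"), ("account", "trans")]
hops = (
    hop_distance(rels, "loan", "loan"),
    hop_distance(rels, "account", "trans"),
    hop_distance(rels, "loan", "trans"),
)
print(hops)
```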

In [None]:
tables, relation_order, dataset_meta = load_multi_table(configs["general"]["data_dir"])

To load existing synthetic data from the stored files instead, uncomment and run the code below.

In [None]:
# # Load original synthesized data
# cleaned_tables = clava_load_synthesized_data(tables.keys(), save_dir) 
# # Load updated synthesized data
# cleaned_tables = clava_load_synthesized_data_updated(tables.keys(), save_dir) 

In [None]:
# Multi-table Evaluation
report = clava_eval(tables, save_dir, configs, relation_order, cleaned_tables)

In [None]:
# Print out the multi-table metrics
n_rows = 3
for key, val in report.items():
    if key in ["hop_relation", "avg_scores", "all_avg_score"]:
        if key == "hop_relation":
            print("{} metrics:".format(key))
            for k, v in val.items():
                if k > 0:
                    print(
                        "{:20}-hop column correlation, format '(Parent Table, Child Table, Column 1, Column 2): Correlation score'".format(
                            k
                        )
                    )
                else:
                    print(
                        "{:20}-hop column correlation, format '(Parent Table, Parent Table, Column 1, Column 2): Correlation score'".format(
                            k
                        )
                    )
                for i_row, (k2, v2) in enumerate(v.items()):
                    if i_row < n_rows:
                        print("{:20} {}: {}".format("", k2, v2))
                print("{:20} ...... other rows are omitted ......\n".format(""))
        elif key == "avg_scores":
            print("{} metrics:".format(key))
            for k, v in val.items():
                print("{:20}-hop column correlation: {}".format(k, v))
        elif key == "all_avg_score":
            print("{:20}: {}".format(key, val))

### Single-Table Metrics
This notebook focuses on multi-table metrics; a separate notebook evaluates single-table metrics for each synthesized table. Please refer to `evalutate_synthetic_data.ipynb` for more details.

<!-- # Prepare the synthetic data and reference data for single-table metric evaluation
shutil.copy(os.path.join(configs['general']['data_dir'], 'dataset_meta.json'), os.path.join(save_dir, 'dataset_meta.json'))
for table_name in tables.keys():
    shutil.copy(os.path.join(save_dir, table_name, '_final', f'{table_name}_synthetic.csv'), os.path.join(save_dir, f'{table_name}.csv'))
    # uncomment and run the following line if you want to use the pre-synthesized data
    # shutil.copy(os.path.join(pretrained_dir, table_name, '_final', f'{table_name}_synthetic.csv'), os.path.join(save_dir, f'{table_name}.csv'))

    shutil.copy(os.path.join(configs['general']['data_dir'], f'{table_name}_domain.json'), os.path.join(save_dir, f'{table_name}_domain.json'))

test_tables, _, _ = load_multi_table(save_dir, verbose=False)
real_tables, _, _ = load_multi_table(configs['general']['data_dir'], verbose=False)

# Single table metrics
for table_name in tables.keys():
    print(f'Generating report for {table_name}')
    real_data = real_tables[table_name]['df']
    syn_data = cleaned_tables[table_name]
    domain_dict = real_tables[table_name]['domain']

    if configs['general']['workspace_dir'] is not None:
        test_data = test_tables[table_name]['df']
    else:
        test_data = None

    gen_single_report(
        real_data, 
        syn_data,
        domain_dict,
        table_name,
        save_dir,
        alpha_beta_sample_size=200_000,
        test_data=test_data
    ) -->

## References

**Pang, Wei, et al.** "ClavaDDPM: Multi-relational Data Synthesis with Cluster-guided Diffusion Models." *preprint* (2024).

**GitHub Repository:** [ClavaDDPM](https://github.com/weipang142857/ClavaDDPM)