# Synthetic data generation using VAEs

## 1. Introduction


<img src="./data/VAE_Basic.png" alt="VAE Structure" width="600"/>


#### 🔍 What is a VAE?

A **Variational Autoencoder** (VAE) is a **generative model** that learns to compress data into a **latent space** and then **reconstruct** it from that space. Unlike traditional autoencoders, VAEs are **probabilistic**: the encoder maps each input to a distribution over latent codes rather than to a single fixed point.

---

#### 🧠 How Does It Work?

A VAE has two main parts:

1. **Encoder**: Maps the input `x` to a probability distribution over the latent variables `z`, described by a **mean** (μ) and a **standard deviation** (σ) for each latent dimension.
2. **Decoder**: Takes a latent vector `z` sampled from this distribution and reconstructs the input as `x'`.
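The sampling step between the two parts is usually written with the reparameterization trick, $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so that gradients can flow back into μ and σ during training. Below is a minimal NumPy sketch of that step; the function name `reparameterize`, the log-variance parameterization, and the encoder outputs are assumptions for illustration, not SDV's code.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z ~ N(mu, diag(sigma^2)) as mu + sigma * eps, with eps ~ N(0, I).

    The randomness lives only in eps, so z is a differentiable function of mu and log_var.
    """
    sigma = np.exp(0.5 * log_var)        # log_var = log(sigma^2)  =>  sigma = exp(log_var / 2)
    eps = rng.standard_normal(mu.shape)  # noise independent of the encoder outputs
    return mu + sigma * eps

# Hypothetical encoder outputs for a batch of 4 inputs and a 2-D latent space.
mu = np.tile([0.0, 1.0], (4, 1))
log_var = np.tile(np.log([1.0, 0.25]), (4, 1))  # variances 1.0 and 0.25, i.e. sigma = 1.0 and 0.5
z = reparameterize(mu, log_var, np.random.default_rng(0))
print(z.shape)  # (4, 2)
```

Predicting log σ² instead of σ is a common design choice because it keeps σ positive without any constraint on the network output.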

---

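* We assume the encoder's distribution over the latent variables (the approximate posterior $q(z \mid x)$) is a **normal distribution** with diagonal covariance:

  $$
  z \sim \mathcal{N}\big(\mu(x), \operatorname{diag}(\sigma(x)^2)\big)
  $$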

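* The **loss function** combines:

  1. **Reconstruction loss** (how well the output matches the input), e.g. Mean Squared Error (MSE), which matches a Gaussian likelihood $p(x \mid z)$ with fixed variance up to a scale factor and an additive constant.
  2. **KL divergence** between the learned distribution $q(z \mid x)$ and the standard normal prior $p(z) = \mathcal{N}(0, I)$:

  $$
  \mathcal{L}(x) = -\mathbb{E}_{q(z \mid x)}[\log p(x \mid z)] + \text{KL}\big(q(z \mid x) \parallel p(z)\big)
  $$

  Minimizing $\mathcal{L}$ is equivalent to maximizing the evidence lower bound (ELBO). The KL term pulls every $q(z \mid x)$ toward the prior, which keeps the latent space smooth and continuous, so that vectors drawn from $\mathcal{N}(0, I)$ decode to plausible samples.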
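The loss above can be computed directly once the encoder outputs μ and log σ². For a diagonal Gaussian $q(z \mid x)$ and the prior $\mathcal{N}(0, I)$, the KL term has the closed form $\text{KL} = -\tfrac{1}{2}\sum_j \big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$. The NumPy sketch below assumes a squared-error reconstruction term; the function name `vae_loss` is hypothetical, and the reconstruction terms may differ from what SDV's TVAE uses internally for mixed column types, so this only illustrates the formula.

```python
import numpy as np

def vae_loss(x, x_recon, mu, log_var):
    """Negative ELBO (to be minimized), averaged over the batch.

    Assumes q(z|x) = N(mu, diag(exp(log_var))), p(z) = N(0, I), and a squared-error
    reconstruction term (Gaussian likelihood with fixed variance, up to scale and constant).
    """
    recon = np.sum((x - x_recon) ** 2, axis=1)                           # per-example squared error
    kl = -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=1)  # per-example KL to N(0, I)
    return float(np.mean(recon + kl)), float(np.mean(recon)), float(np.mean(kl))

# One example: reconstruction off by 1 in one feature, latent mean shifted by 1 in one dimension.
x = np.array([[1.0, 2.0, 3.0]])
x_recon = np.array([[1.0, 2.0, 2.0]])
mu = np.array([[1.0, 0.0]])
log_var = np.zeros((1, 2))  # sigma = 1 in both latent dimensions
total, recon, kl = vae_loss(x, x_recon, mu, log_var)
print(f"loss={total:.2f} recon={recon:.2f} kl={kl:.2f}")  # loss=1.50 recon=1.00 kl=0.50
```

The KL term is zero only when μ = 0 and σ = 1 in every latent dimension, i.e. when $q(z \mid x)$ equals the prior.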

---

#### ✅ Why Does It Matter?

* VAEs are well suited to **generating new data**, **dimensionality reduction**, and learning **structured latent representations**.
* Unlike standard autoencoders, VAEs **can generate new, varied samples**: because the KL term keeps the encoded distributions close to the prior $\mathcal{N}(0, I)$, a vector drawn from that prior decodes to plausible data.
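Generating new data needs only the decoder: draw $z$ from the prior $\mathcal{N}(0, I)$ and decode it. In the minimal sketch below, `generate` and `toy_decoder` are hypothetical; `toy_decoder` is a fixed linear map standing in for a trained decoder network, so the samples only illustrate the mechanics, not realistic data.

```python
import numpy as np

def generate(decoder, n_samples, latent_dim, rng):
    """Generate new rows by decoding latent vectors drawn from the prior N(0, I)."""
    z = rng.standard_normal((n_samples, latent_dim))  # no encoder or input data is needed here
    return decoder(z)

# Hypothetical stand-in for a trained decoder: a fixed linear map from 2-D latents to 3 features.
W = np.array([[1.0, 0.0, 0.5],
              [0.0, 2.0, -1.0]])
b = np.array([10.0, 0.0, 5.0])

def toy_decoder(z):
    return z @ W + b

samples = generate(toy_decoder, n_samples=5, latent_dim=2, rng=np.random.default_rng(42))
print(samples.shape)  # (5, 3)
```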

---



## 2. SDV (Synthetic Data Vault)

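SDV is a Python library for generating synthetic data, originally developed at MIT. Its `TVAESynthesizer` trains a VAE adapted to tabular data (TVAE) and then samples new rows from it.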

### A. Import the data and packages

```python
from sdv.metadata import SingleTableMetadata  # metadata extractor
from sdv.single_table import TVAESynthesizer  # VAE-based data synthesizer
import pprint  # prints dictionaries nicely
import pandas as pd
```

```python
# data source: https://www.kaggle.com/datasets/uciml/german-credit
df = pd.read_csv("./data/german_credit_data.csv", index_col=0)
```

```python
df.describe(include='all').T.sort_values("freq")
```

| Column | count | unique | top | freq | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Checking account | 606.0 | 3.0 | little | 274.0 | | | | | | | |
| Purpose | 1000.0 | 8.0 | car | 337.0 | | | | | | | |
| Saving accounts | 817.0 | 4.0 | little | 603.0 | | | | | | | |
| Sex | 1000.0 | 2.0 | male | 690.0 | | | | | | | |
| Housing | 1000.0 | 3.0 | own | 713.0 | | | | | | | |
| Age | 1000.0 | | | | 35.546 | 11.375469 | 19.0 | 27.0 | 33.0 | 42.0 | 75.0 |
| Job | 1000.0 | | | | 1.904 | 0.653614 | 0.0 | 2.0 | 2.0 | 2.0 | 3.0 |
| Credit amount | 1000.0 | | | | 3271.258 | 2822.736876 | 250.0 | 1365.5 | 2319.5 | 3972.25 | 18424.0 |
| Duration | 1000.0 | | | | 20.903 | 12.058814 | 4.0 | 12.0 | 18.0 | 24.0 | 72.0 |


```python
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(df)
pprint.pprint(metadata.to_dict())
```

```
{'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
 'columns': {'Age': {'sdtype': 'numerical'},
             'Checking account': {'sdtype': 'categorical'},
             'Credit amount': {'sdtype': 'numerical'},
             'Duration': {'sdtype': 'numerical'},
             'Housing': {'sdtype': 'categorical'},
             'Job': {'sdtype': 'categorical'},
             'Purpose': {'sdtype': 'categorical'},
             'Saving accounts': {'sdtype': 'categorical'},
             'Sex': {'sdtype': 'categorical'}}}
```


```python
# override a detected sdtype if needed (here it confirms "Saving accounts" as categorical)
metadata.update_column(column_name="Saving accounts",
                       sdtype="categorical")
```

```python
# re-print the metadata; creating a new object and calling detect_from_dataframe here would discard the update above
pprint.pprint(metadata.to_dict())
```

```
{'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
 'columns': {'Age': {'sdtype': 'numerical'},
             'Checking account': {'sdtype': 'categorical'},
             'Credit amount': {'sdtype': 'numerical'},
             'Duration': {'sdtype': 'numerical'},
             'Housing': {'sdtype': 'categorical'},
             'Job': {'sdtype': 'categorical'},
             'Purpose': {'sdtype': 'categorical'},
             'Saving accounts': {'sdtype': 'categorical'},
             'Sex': {'sdtype': 'categorical'}}}
```


### B. Train a VAE to generate synthetic data

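```python
# train TVAE for 10,000 epochs; verbose=True prints the loss progress bar
synthesizer = TVAESynthesizer(metadata, epochs=10000, verbose=True)
synthesizer.fit(df)
# draw 10,000 synthetic rows (10x the size of the original data)
synthetic_data = synthesizer.sample(num_rows=10000)
```

```
Loss: -5.670: 100%|███████████████████████| 10000/10000 [09:33<00:00, 17.45it/s]
```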
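A negative training loss (here -5.670) can look surprising for a reconstruction objective. One likely explanation, assuming (not verified against the installed SDV version) that TVAE's reconstruction term for numerical columns is a Gaussian negative log-likelihood with a learned standard deviation, is that a Gaussian density exceeds 1 when σ is small, so its negative log becomes negative. The sketch below only demonstrates that arithmetic; `gaussian_nll` is a hypothetical helper, not an SDV function.

```python
import math

def gaussian_nll(x, mean, sigma):
    """Negative log-likelihood of a scalar x under N(mean, sigma^2)."""
    return 0.5 * math.log(2 * math.pi * sigma**2) + (x - mean) ** 2 / (2 * sigma**2)

# Wide Gaussian, perfect reconstruction: the density is below 1, so the NLL is positive.
print(round(gaussian_nll(0.50, 0.50, sigma=1.0), 4))   # 0.9189
# Narrow Gaussian, small error: the density is above 1, so the NLL is negative.
print(round(gaussian_nll(0.50, 0.51, sigma=0.05), 4))  # -2.0568
```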


```python
synthetic_data.describe(include='all').T.sort_values("freq")
```

| Column | count | unique | top | freq | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Checking account | 5827.0 | 3.0 | little | 2693.0 | | | | | | | |
| Purpose | 10000.0 | 8.0 | car | 3047.0 | | | | | | | |
| Saving accounts | 8278.0 | 4.0 | little | 6067.0 | | | | | | | |
| Sex | 10000.0 | 2.0 | male | 7100.0 | | | | | | | |
| Housing | 10000.0 | 3.0 | own | 7551.0 | | | | | | | |
| Age | 10000.0 | | | | 36.0616 | 11.535273 | 19.0 | 27.0 | 33.0 | 43.0 | 75.0 |
| Job | 10000.0 | | | | 1.9079 | 0.60684 | 0.0 | 2.0 | 2.0 | 2.0 | 3.0 |
| Credit amount | 10000.0 | | | | 3136.5074 | 2470.42188 | 250.0 | 1442.0 | 2334.5 | 3880.25 | 17287.0 |
| Duration | 10000.0 | | | | 20.482 | 11.659283 | 4.0 | 12.0 | 17.0 | 24.0 | 65.0 |
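Before the classifier-based check, the two `describe()` tables can already be compared by hand. The sketch below hard-codes the summary numbers shown above (nothing is recomputed from the data) and compares, per column, the share of non-missing values, the share of the most frequent category, and the relative difference of numerical means and standard deviations. The helper names `shares` and `rel_diff` are hypothetical.

```python
# Summary statistics copied from the real (n=1000) and synthetic (n=10000) describe() tables above.
N_REAL, N_SYN = 1000, 10000

# categorical column -> (count of non-missing values, freq of the top category);
# the top category ("little", "car", "male", "own") is the same in both tables
REAL_CAT = {"Checking account": (606, 274), "Purpose": (1000, 337), "Saving accounts": (817, 603),
            "Sex": (1000, 690), "Housing": (1000, 713)}
SYN_CAT = {"Checking account": (5827, 2693), "Purpose": (10000, 3047), "Saving accounts": (8278, 6067),
           "Sex": (10000, 7100), "Housing": (10000, 7551)}

# numerical column -> (mean, std)
REAL_NUM = {"Age": (35.546, 11.375469), "Job": (1.904, 0.653614),
            "Credit amount": (3271.258, 2822.736876), "Duration": (20.903, 12.058814)}
SYN_NUM = {"Age": (36.0616, 11.535273), "Job": (1.9079, 0.60684),
           "Credit amount": (3136.5074, 2470.42188), "Duration": (20.482, 11.659283)}

def shares(count, freq, n_rows):
    """Return (fraction of rows that are non-missing, top-category share among non-missing rows)."""
    return count / n_rows, freq / count

def rel_diff(real, synthetic):
    """Relative difference of a synthetic statistic with respect to the real one."""
    return (synthetic - real) / real

for col in REAL_CAT:
    r_fill, r_top = shares(*REAL_CAT[col], N_REAL)
    s_fill, s_top = shares(*SYN_CAT[col], N_SYN)
    print(f"{col:<17} non-missing {r_fill:.3f} vs {s_fill:.3f} | top share {r_top:.3f} vs {s_top:.3f}")

for col in REAL_NUM:
    (r_mean, r_std), (s_mean, s_std) = REAL_NUM[col], SYN_NUM[col]
    print(f"{col:<17} mean {rel_diff(r_mean, s_mean):+.1%} | std {rel_diff(r_std, s_std):+.1%}")
```

With these numbers, the non-missing and top-category shares differ by at most about 4 percentage points (for example, `Housing` = "own" goes from 0.713 to 0.755 and `Purpose` = "car" from 0.337 to 0.305), while the synthetic `Credit amount` has a standard deviation about 12.5% lower than the real one (2470.4 vs 2822.7) and a smaller maximum (17287 vs 18424), suggesting its long right tail is somewhat under-represented.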


### C. Validate results using a Random Forest classifier