## Imports


## Explanation of Libraries

- `pmlb`: Fetches datasets from the Penn Machine Learning Benchmark.
- `numpy`: Provides array and numerical operations.
- `pandas`: Handles data manipulation and analysis.
- `scikit-learn`: Provides the preprocessing tools (scaling and one-hot encoding), the train/test split, and the `GaussianMixture` model.
- `tqdm`: Displays progress bars during iterations.
- `os`: Builds the path of the output file.


In [2]:
# If pmlb or tqdm is not installed, uncomment and run this cell, then restart the kernel.
# !pip install pmlb
# !pip install tqdm

In [3]:
# Imports
import os
from numpy import arange, array, concatenate, ndarray, split, zeros
from pandas import DataFrame, concat
from pmlb import fetch_data
from sklearn.compose import ColumnTransformer
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from tqdm import tqdm

## Fetching the Dataset

In this section, we use the `pmlb` library to load the Adult Census Income dataset.
The data is then split into training and testing sets: the generator is fit on the training set, and the test set is held out for evaluation.


# Tabular Synthetic Data Generation with Gaussian Mixture

- This notebook shows how to use a synthetic data generation method based on a [Gaussian Mixture Model (GMM)](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html) to generate synthetic tabular data with numeric and categorical features.

## Dataset

- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) dataset, which we fetch with the `pmlb` library (a wrapper for the Penn Machine Learning Benchmark data repository).


## Gaussian Mixture Model (GMM)

A Gaussian Mixture Model (GMM) is a probabilistic model that assumes that all the data points are generated from a mixture of several Gaussian distributions with unknown parameters.
This model is commonly used for clustering or generating synthetic data.
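In standard notation (the symbols below are ours, not the notebook's), a GMM with $K$ components assigns a data point $x$ the density

$$
p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}\left(x \mid \mu_k, \Sigma_k\right), \qquad \pi_k \ge 0, \qquad \sum_{k=1}^{K} \pi_k = 1,
$$

where $\pi_k$ is the mixing weight, $\mu_k$ the mean vector, and $\Sigma_k$ the covariance matrix of component $k$.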

### How GMM Works:

- GMM represents data as a mixture of multiple normal (Gaussian) distributions, called components.
- Each component is defined by a mean vector and a covariance matrix, and is assigned a mixing weight (the probability that a data point was generated by that component).
- The model uses an iterative algorithm called Expectation-Maximization (EM) to estimate the parameters (means, covariances, and weights) of every component.
- Once trained, the model can generate new data points by sampling from the learned distribution, or cluster data by computing each point's probability of belonging to each component.
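The generation step in the last bullet can be sketched in plain NumPy as ancestral sampling: first pick a component according to the mixing weights, then draw the point from that component's Gaussian. This is a minimal illustration with hand-picked parameters; `sample_gmm` and the example weights, means and covariances are hypothetical. scikit-learn's `GaussianMixture.sample` performs an equivalent draw using the fitted `weights_`, `means_` and `covariances_`.

```python
import numpy as np


def sample_gmm(weights, means, covariances, n_samples, seed=0):
    """Draw samples from a Gaussian mixture by ancestral sampling.

    weights: (K,) mixing weights summing to 1
    means: (K, d) component means
    covariances: (K, d, d) full covariance matrices
    """
    rng = np.random.default_rng(seed)
    weights = np.asarray(weights, dtype=float)
    means = np.asarray(means, dtype=float)
    covariances = np.asarray(covariances, dtype=float)

    # Step 1: choose a component for every sample according to the mixing weights
    labels = rng.choice(len(weights), size=n_samples, p=weights)

    # Step 2: draw each sample from the Gaussian of its chosen component
    samples = np.empty((n_samples, means.shape[1]))
    for k in range(len(weights)):
        idx = np.flatnonzero(labels == k)
        if idx.size:
            samples[idx] = rng.multivariate_normal(means[k], covariances[k], size=idx.size)
    return samples, labels


# Two well-separated 2-D components with weights 0.3 / 0.7 (hypothetical values)
X, labels = sample_gmm(
    weights=[0.3, 0.7],
    means=[[0.0, 0.0], [5.0, 5.0]],
    covariances=[np.eye(2), 0.5 * np.eye(2)],
    n_samples=1000,
)
print(X.shape, np.bincount(labels) / len(labels))
```

The returned labels record which component produced each row, which is also what `GaussianMixture.sample` returns as the second element of its result.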


# Load the data

This dataset is a common machine learning benchmark for predicting whether a person earns more than $50K a year from 14 other features.


In [4]:
# Load data
data = fetch_data("adult")

In [5]:
# Specify which columns are numeric and which are categorical; each group gets its own preprocessing pipeline
num_cols = ["age", "fnlwgt", "capital-gain", "capital-loss", "hours-per-week"]
cat_cols = [
    "workclass",
    "education",
    "education-num",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]
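Because every feature must land in exactly one of the two lists, a quick sanity check can catch typos or forgotten columns. The helper below, `check_column_split`, is hypothetical and shown on a toy DataFrame; applied as `check_column_split(data, num_cols, cat_cols)` with the lists above, it should return three empty lists, since together they cover all 14 feature columns shown in the data preview.

```python
from pandas import DataFrame


def check_column_split(df, num_cols, cat_cols, target="target"):
    """Return (missing, duplicated, unknown) column names for a numeric/categorical split."""
    features = set(df.columns) - {target}
    listed = list(num_cols) + list(cat_cols)
    missing = sorted(features - set(listed))           # features in neither list
    duplicated = sorted(set(num_cols) & set(cat_cols))  # features in both lists
    unknown = sorted(set(listed) - features)            # listed names not in the data
    return missing, duplicated, unknown


# Toy example: "race" has not been assigned to either list
toy = DataFrame({"age": [39, 50], "sex": [1, 0], "race": [4, 2], "target": [1, 0]})
print(check_column_split(toy, ["age"], ["sex"]))  # → (['race'], [], [])
```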

In [6]:
data

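| | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 39.0 | 7 | 77516.0 | 9 | 13.0 | 4 | 1 | 1 | 4 | 1 | 2174.0 | 0.0 | 40.0 | 39 | 1 |
| 1 | 50.0 | 6 | 83311.0 | 9 | 13.0 | 2 | 4 | 0 | 4 | 1 | 0.0 | 0.0 | 13.0 | 39 | 1 |
| 2 | 38.0 | 4 | 215646.0 | 11 | 9.0 | 0 | 6 | 1 | 4 | 1 | 0.0 | 0.0 | 40.0 | 39 | 1 |
| 3 | 53.0 | 4 | 234721.0 | 1 | 7.0 | 2 | 6 | 0 | 2 | 1 | 0.0 | 0.0 | 40.0 | 39 | 1 |
| 4 | 28.0 | 4 | 338409.0 | 9 | 13.0 | 2 | 10 | 5 | 2 | 0 | 0.0 | 0.0 | 40.0 | 5 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 48837 | 39.0 | 4 | 215419.0 | 9 | 13.0 | 0 | 10 | 1 | 4 | 0 | 0.0 | 0.0 | 36.0 | 39 | 1 |
| 48838 | 64.0 | 0 | 321403.0 | 11 | 9.0 | 6 | 0 | 2 | 2 | 1 | 0.0 | 0.0 | 40.0 | 39 | 1 |
| 48839 | 38.0 | 4 | 374983.0 | 9 | 13.0 | 2 | 10 | 0 | 4 | 1 | 0.0 | 0.0 | 50.0 | 39 | 1 |
| 48840 | 44.0 | 4 | 83891.0 | 9 | 13.0 | 0 | 1 | 3 | 1 | 1 | 5455.0 | 0.0 | 40.0 | 39 | 1 |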


In [7]:
X = data.drop("target", axis=1)
y = data["target"]

In [8]:
# Unseeded split with scikit-learn's default 75/25 train/test ratio
X_train, X_test, y_train, y_test = train_test_split(X, y)

## Create and Train the synthetic data generator


In [9]:
# Instantiate pipelines: min-max scaling for numeric columns, one-hot encoding for categorical columns
num_pipeline = Pipeline([("scaler", MinMaxScaler())])
cat_pipeline = Pipeline(
    [
        ("encoder", OneHotEncoder(sparse_output=False, handle_unknown="ignore")),
    ]
)

In [10]:
num_pipeline.fit(X_train[num_cols], y_train)
cat_pipeline.fit(X_train[cat_cols], y_train)

In [11]:
# Number of output columns produced by each pipeline (used later to split generated samples)
num_col_idx = len(num_pipeline.get_feature_names_out())
cat_col_idx = len(cat_pipeline.get_feature_names_out())

In [12]:
cat_pipeline.get_feature_names_out()

array(['workclass_0', 'workclass_1', 'workclass_2', 'workclass_3',
       'workclass_4', 'workclass_5', 'workclass_6', 'workclass_7',
       'workclass_8', 'education_0', 'education_1', 'education_2',
       'education_3', 'education_4', 'education_5', 'education_6',
       'education_7', 'education_8', 'education_9', 'education_10',
       'education_11', 'education_12', 'education_13', 'education_14',
       'education_15', 'education-num_1.0', 'education-num_2.0',
       'education-num_3.0', 'education-num_4.0', 'education-num_5.0',
       'education-num_6.0', 'education-num_7.0', 'education-num_8.0',
       'education-num_9.0', 'education-num_10.0', 'education-num_11.0',
       'education-num_12.0', 'education-num_13.0', 'education-num_14.0',
       'education-num_15.0', 'education-num_16.0', 'marital-status_0',
       'marital-status_1', 'marital-status_2', 'marital-status_3',
       'marital-status_4', 'marital-status_5', 'marital-status_6',
       'occupation_0', 'occupation_1',

In [13]:
cat_col_idx

117

In [14]:
trans_num = num_pipeline.transform(X_train[num_cols])
trans_cat = cat_pipeline.transform(X_train[cat_cols])

In [15]:
trans_cat.shape

(36631, 117)

In [16]:
transformed = concatenate([trans_num, trans_cat], axis=1)
transformed.shape

(36631, 122)

In [17]:
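# Initialize model. n_components defaults to 1, so as configured this fits a single
# full-covariance Gaussian; increase n_components for a multi-component mixture.
gmm = GaussianMixture(n_components=1, covariance_type="full", random_state=0)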

In [18]:
# Fit the model
gmm.fit(transformed)
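`GaussianMixture` uses a single component unless `n_components` is set. One common way to choose the number of components is to fit several candidates and keep the one with the lowest Bayesian Information Criterion (`GaussianMixture.bic`). The sketch below runs on hypothetical 2-D data with three clusters; `select_n_components` is an illustrative helper, not part of scikit-learn. On the 122-column encoded census data each full-covariance fit is considerably more expensive, so a smaller candidate range or `covariance_type="diag"` may be more practical.

```python
import numpy as np
from sklearn.mixture import GaussianMixture


def select_n_components(X, candidates, random_state=0):
    """Fit one GMM per candidate component count and return the one with the lowest BIC."""
    best_model, best_bic = None, np.inf
    for k in candidates:
        model = GaussianMixture(
            n_components=k, covariance_type="full", n_init=3, random_state=random_state
        ).fit(X)
        bic = model.bic(X)  # lower is better: -2 * log-likelihood + parameter penalty
        if bic < best_bic:
            best_model, best_bic = model, bic
    return best_model


# Hypothetical data: three well-separated 2-D clusters of 200 points each
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [10.0, 10.0], [-10.0, 10.0]])
X = np.vstack([rng.normal(c, 1.0, size=(200, 2)) for c in centers])

best = select_n_components(X, candidates=range(1, 7))
print(best.n_components)
```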

In [19]:
# sample() returns (samples, component labels); keep only the samples
sample, _ = gmm.sample(n_samples=100)

## Generate new synthetic data


In [20]:
# Split each generated row back into its scaled numeric block and its one-hot block
num_gen, cat_gen = split(sample, [num_col_idx], axis=1)

In [21]:
num_gen_2 = num_pipeline.inverse_transform(num_gen)

In [22]:
cat_gen_2 = cat_pipeline.inverse_transform(cat_gen)
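The sampled one-hot columns are continuous values rather than exact 0/1 indicators. As far as we can tell, scikit-learn's `OneHotEncoder.inverse_transform` resolves this by picking, within each feature's block of columns, the category with the largest value. The NumPy sketch below reimplements that idea to make it explicit; `decode_one_hot_blocks` is a hypothetical helper, and in practice the block sizes and category labels would come from the encoder's `categories_` attribute.

```python
import numpy as np


def decode_one_hot_blocks(values, block_sizes, categories):
    """Map each block of (possibly non-binary) one-hot scores to the category with the largest score."""
    decoded = []
    start = 0
    for size, cats in zip(block_sizes, categories):
        block = values[:, start : start + size]
        decoded.append(np.asarray(cats)[block.argmax(axis=1)])
        start += size
    return np.column_stack(decoded)


# Two rows of continuous GMM-style output for two features with 3 and 2 categories
scores = np.array(
    [
        [0.1, 0.8, 0.3, -0.2, 1.1],
        [0.6, 0.2, 0.4, 0.9, 0.05],
    ]
)
decoded = decode_one_hot_blocks(scores, [3, 2], [["a", "b", "c"], ["x", "y"]])
print(decoded)  # rows decode to ['b', 'y'] and ['a', 'x']
```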

In [23]:
result = DataFrame(num_gen_2, columns=num_cols)
result

| | age | fnlwgt | capital-gain | capital-loss | hours-per-week |
|---|---|---|---|---|---|
| 0 | 38.354311 | 176995.617207 | -1956.700188 | 449.756737 | 48.476458 |
| 1 | 43.485496 | 279082.064494 | 6323.310376 | -301.136978 | 19.971347 |
| 2 | 57.895721 | 195512.072341 | -7481.784133 | 175.616500 | 43.942022 |
| 3 | 52.603160 | 69542.487454 | 1872.106754 | -119.328895 | 42.677735 |
| 4 | 57.417315 | 178937.155405 | 3290.178825 | 218.300192 | 57.122104 |
| ... | ... | ... | ... | ... | ... |
| 95 | 33.094596 | 132562.665876 | 5855.802756 | 366.575522 | 42.847048 |
| 96 | 60.811961 | 293593.050324 | 4573.700357 | 828.382601 | 44.220164 |
| 97 | 28.737110 | 244114.643038 | 1218.469569 | 332.540117 | 34.757925 |
| 98 | 45.455305 | 302016.126806 | 3540.238400 | -251.616660 | 61.351343 |
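The generated numeric values include negative `capital-gain` and `capital-loss` amounts and fractional ages, because `MinMaxScaler.inverse_transform` only undoes the scaling and does not keep values inside the range seen during training. One possible post-processing step, sketched below on hypothetical values, is to clip the scaled samples to [0, 1] before inverting and then round columns that hold whole numbers in the source data. This is an assumption about what counts as realistic output, not part of the original notebook, and clipping piles extra probability mass onto the range boundaries.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical training values for two numeric columns (e.g. age, capital-gain)
train = np.array([[17.0, 0.0], [45.0, 2174.0], [90.0, 99999.0]])
scaler = MinMaxScaler().fit(train)

# Scaled GMM output can fall outside [0, 1]; clip it to the training range first
generated_scaled = np.array([[0.3, -0.02], [1.1, 0.5]])
clipped = np.clip(generated_scaled, 0.0, 1.0)
restored = scaler.inverse_transform(clipped)

# Round to whole numbers, since these columns only hold integers in the source data
restored = np.round(restored)
print(restored)
```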


In [24]:
result = concat(
    [DataFrame(num_gen_2, columns=num_cols), DataFrame(cat_gen_2, columns=cat_cols)],
    axis=1,
)

In [25]:
result

| | age | fnlwgt | capital-gain | capital-loss | hours-per-week | workclass | education | education-num | marital-status | occupation | relationship | race | sex | native-country |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 38.354311 | 176995.617207 | -1956.700188 | 449.756737 | 48.476458 | 5.0 | 9.0 | 13.0 | 4.0 | 8.0 | 1.0 | 4.0 | 1.0 | 39.0 |
| 1 | 43.485496 | 279082.064494 | 6323.310376 | -301.136978 | 19.971347 | 4.0 | 15.0 | 10.0 | 4.0 | 4.0 | 3.0 | 1.0 | 0.0 | 39.0 |
| 2 | 57.895721 | 195512.072341 | -7481.784133 | 175.616500 | 43.942022 | 4.0 | 11.0 | 9.0 | 2.0 | 13.0 | 1.0 | 4.0 | 1.0 | 39.0 |
| 3 | 52.603160 | 69542.487454 | 1872.106754 | -119.328895 | 42.677735 | 4.0 | 11.0 | 9.0 | 2.0 | 7.0 | 0.0 | 4.0 | 1.0 | 39.0 |
| 4 | 57.417315 | 178937.155405 | 3290.178825 | 218.300192 | 57.122104 | 4.0 | 7.0 | 12.0 | 4.0 | 5.0 | 0.0 | 2.0 | 1.0 | 39.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 33.094596 | 132562.665876 | 5855.802756 | 366.575522 | 42.847048 | 4.0 | 11.0 | 9.0 | 2.0 | 12.0 | 0.0 | 4.0 | 1.0 | 39.0 |
| 96 | 60.811961 | 293593.050324 | 4573.700357 | 828.382601 | 44.220164 | 2.0 | 11.0 | 9.0 | 0.0 | 4.0 | 4.0 | 4.0 | 1.0 | 39.0 |
| 97 | 28.737110 | 244114.643038 | 1218.469569 | 332.540117 | 34.757925 | 6.0 | 15.0 | 10.0 | 4.0 | 12.0 | 1.0 | 4.0 | 1.0 | 39.0 |
| 98 | 45.455305 | 302016.126806 | 3540.238400 | -251.616660 | 61.351343 | 4.0 | 15.0 | 10.0 | 2.0 | 1.0 | 4.0 | 4.0 | 0.0 | 39.0 |


In [26]:
# Write the synthetic rows to ../output/ as a tab-separated file without the index
result.to_csv(os.path.join(os.getcwd(), "..", "output", "census_data_output.tsv"), sep="\t", index=False)