Skip to content

Commit

Permalink
feat: add new gmm based synth for fast synthesis (#269)
Browse files Browse the repository at this point in the history
* feat: Add new GMM model for fast synthesis

* feat: add save and load for new model

* fix: synthesis base class

* fix: linter

* fix: linter warnings
  • Loading branch information
fabclmnt committed May 23, 2023
1 parent 2f6fd89 commit 81abe1d
Show file tree
Hide file tree
Showing 17 changed files with 415 additions and 36 deletions.
24 changes: 15 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,30 @@ Join us on [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the
# YData Synthetic
A package to generate synthetic tabular and time-series data leveraging the state of the art generative models.

## 🎊 We have **big news**: v1.0.0 is here
> We have exciting news for you. The new version of `ydata-synthetic` include new and exciting features:
## 🎊 The exciting features:
> These are must try features whne it comes to synthetic data generation:
> - A new streamlit app that delivers the synthetic data generation experience with a UI interface. A low code experience for the quick generation of synthetic data
> - A new fast synthetic data generation model based on Gaussian Mixture. So you can quickstart in the world of synthetic data generation without the need for a GPU.
> - A conditional architecture for tabular data: CTGAN, which will make the process of synthetic data generation easier and with higher quality!
> - A new streamlit app that delivers the synthetic data generation experience with a UI interface
## Synthetic data
### What is synthetic data?
Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals' privacy.

### Why Synthetic Data?
Synthetic data can be used for many applications:
- Privacy
- Privacy compliance for data-sharing and Machine Learning development
- Remove bias
- Balance datasets
- Augment datasets

# ydata-synthetic
This repository contains material related with Generative Adversarial Networks for synthetic data generation, in particular regular tabular data and time-series.
It consists a set of different GANs architectures developed using Tensorflow 2.0. Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures.
This repository contains material related with architectures and models for synthetic data, from Generative Adversarial Networks (GANs) to Gaussian Mixtures.
The repo includes a full ecosystem for synthetic data generation, that includes different models for the generation of synthetic structure data and time-series.
All the Deep Learning models are implemented leveraging Tensorflow 2.0.
Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures.

Are you ready to learn more about synthetic data and the bext-practices for synthetic data generation?

## Quickstart
The source code is currently hosted on GitHub at: https://github.com/ydataai/ydata-synthetic
Expand Down Expand Up @@ -78,8 +83,8 @@ The below models are supported:

### Examples
Here you can find usage examples of the package and models to synthesize tabular data.

- Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Data-Centric-AI-Community/awesome-python-for-data-science/blob/main/workshop-ds/Workshop%20-%20Data-Centric%20AI%20pipelines%20-%20How%20and%20why.ipynb)
- Fast tabular data synthesis on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/Fast_Adult_Census_Income_Data.ipynb)
- Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/CTGAN_Adult_Census_Income_Data.ipynb)
- Time Series synthetic data generation with TimeGAN on stock dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/timeseries/TimeGAN_Synthetic_stock_data.ipynb)
- More examples are continuously added and can be found in `/examples` directory.

Expand All @@ -106,6 +111,7 @@ In this repository you can find the several GAN architectures that are used to c
- [Cramer GAN (The Cramer Distance as a Solution to Biased Wasserstein Gradients)](https://arxiv.org/abs/1705.10743)
- [CWGAN-GP (Conditional Wassertein GAN with Gradient Penalty)](https://cameronfabbri.github.io/papers/conditionalWGAN.pdf)
- [CTGAN (Conditional Tabular GAN)](https://arxiv.org/pdf/1907.00503.pdf)
- [Gaussian Mixture](https://towardsdatascience.com/gaussian-mixture-models-explained-6986aaf5a95)

### Sequential data
- [TimeGAN](https://papers.nips.cc/paper/2019/file/c9efe5f26cd17ba6216bbe2a7d26d490-Paper.pdf)
Expand Down
203 changes: 203 additions & 0 deletions examples/regular/models/Fast_Adult_Census_Income_Data.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "code",
"source": [
"#Uncomment to install ydata-synthetic lib\n",
"#!pip install ydata-synthetic"
],
"metadata": {
"id": "fwXSWiYu_tl0",
"pycharm": {
"name": "#%%\n"
}
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Tabular Synthetic Data Generation with Gaussian Mixture\n",
"- This notebook is an example of how to use a synthetic data generation methods based on [GMM](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html) to generate synthetic tabular data with numeric and categorical features.\n",
"\n",
"## Dataset\n",
"- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) which we will fecth by importing the `pmlb` library (a wrapper for the Penn Machine Learning Benchmark data repository).\n"
],
"metadata": {
"id": "6T8gjToi_yKA",
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"source": [
"from pmlb import fetch_data\n",
"\n",
"from ydata_synthetic.synthesizers.regular import RegularSynthesizer\n",
"from ydata_synthetic.synthesizers import ModelParameters, TrainParameters"
],
"metadata": {
"id": "Ix4gZ9iSCVZI",
"pycharm": {
"name": "#%%\n"
}
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Load the data"
],
"metadata": {
"id": "I0qyPwoECZ5x",
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"source": [
"# Load data\n",
"data = fetch_data('adult')\n",
"num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
"cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
" 'native-country', 'target']"
],
"metadata": {
"id": "YeFPnJVOMVqd",
"pycharm": {
"name": "#%%\n"
}
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Create and Train the synthetic data generator"
],
"metadata": {
"id": "68MoepO0Cpx6",
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"source": [
"synth = RegularSynthesizer(modelname='fast')\n",
"synth.fit(data=data, num_cols=num_cols, cat_cols=cat_cols)"
],
"metadata": {
"id": "oIHMVgSZMg8_",
"pycharm": {
"name": "#%%\n"
}
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Generate new synthetic data"
],
"metadata": {
"id": "xHK-SRPyDUin",
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"source": [
"synth_data = synth.sample(1000)\n",
"print(synth_data)"
],
"metadata": {
"id": "0aa2g0RLMkqe",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "01808aa4-a700-4385-e7df-b2f7abd162a0",
"pycharm": {
"name": "#%%\n"
}
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" age workclass fnlwgt education education-num \\\n",
"0 38.753654 4 179993.565472 8 10.0 \n",
"1 36.408844 4 245841.807958 9 10.0 \n",
"2 56.251066 4 400895.076058 11 13.0 \n",
"3 26.846605 4 240156.201048 11 10.0 \n",
"4 29.083102 1 5601.059126 11 9.0 \n",
".. ... ... ... ... ... \n",
"995 79.281276 4 30664.183560 1 10.0 \n",
"996 51.423132 4 414524.980527 1 10.0 \n",
"997 17.342915 6 177716.451926 11 13.0 \n",
"998 39.298867 4 132011.369567 15 12.0 \n",
"999 46.977763 2 92662.371635 9 13.0 \n",
"\n",
" marital-status occupation relationship race sex capital-gain \\\n",
"0 4 0 3 4 0 55.771499 \n",
"1 6 7 0 4 1 124.337939 \n",
"2 4 3 3 4 1 27.968087 \n",
"3 4 6 1 4 0 25.065678 \n",
"4 6 3 0 4 0 126.269337 \n",
".. ... ... ... ... ... ... \n",
"995 2 0 3 4 1 4.393001 \n",
"996 4 7 3 2 0 54.841598 \n",
"997 4 4 4 4 0 99.394428 \n",
"998 4 14 1 4 1 97.834797 \n",
"999 4 8 1 4 0 51.258308 \n",
"\n",
" capital-loss hours-per-week native-country target \n",
"0 -1.271118 39.749641 39 1 \n",
"1 -2.114950 44.488198 39 1 \n",
"2 1.541738 40.042696 39 1 \n",
"3 1.148560 39.952615 39 1 \n",
"4 -1.786768 39.808085 39 0 \n",
".. ... ... ... ... \n",
"995 0.224015 50.580637 39 1 \n",
"996 1.319341 4.441194 39 1 \n",
"997 -5.231663 39.779674 39 1 \n",
"998 1.595817 39.731359 13 1 \n",
"999 1.129814 39.838415 39 1 \n",
"\n",
"[1000 rows x 15 columns]\n"
]
}
]
}
]
}
7 changes: 7 additions & 0 deletions src/ydata_synthetic/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ydata_synthetic.preprocessing.regular.processor import RegularDataProcessor
from ydata_synthetic.preprocessing.timeseries.timeseries_processor import TimeSeriesDataProcessor

__all__ = [
"RegularDataProcessor",
"TimeSeriesDataProcessor"
]
7 changes: 6 additions & 1 deletion src/ydata_synthetic/synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from ydata_synthetic.synthesizers.gan import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.base import ModelParameters, TrainParameters

__all__ = [
"ModelParameters",
"TrainParameters"
]
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"Implements a GAN BaseModel synthesizer, not meant to be directly instantiated."
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import List, Optional, Union

import pandas as pd
import tqdm

from numpy import array, vstack, ndarray
Expand Down Expand Up @@ -40,19 +42,57 @@
ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df)
TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True))

@typechecked
class BaseModel(ABC):
"""
Abstract class for synthetic data generation nmodels
The main methods are train (for fitting the synthesizer), save/load and sample (generating synthetic records).
"""
__MODEL__ = None

@abstractmethod
def fit(self, data: Union[DataFrame, array],
num_cols: Optional[List[str]] = None,
cat_cols: Optional[List[str]] = None):
"""
### Description:
Trains and fit a synthesizer model to a given input dataset.
### Args:
`data` (Union[DataFrame, array]): Training data
`num_cols` (Optional[List[str]]) : List with the names of the categorical columns
`cat_cols` (Optional[List[str]]): List of names of categorical columns
### Returns:
**self:** *object*
Fitted synthesizer
"""
...
@abstractmethod
def sample(self, n_samples:int) -> pd.DataFrame:
assert n_samples>0, "Please insert a value bigger than 0 for n_samples parameter."
...

@classmethod
def load(cls, path: str):
...

@abstractmethod
def save(self, path: str):
...

# pylint: disable=R0902
@typechecked
class BaseModel():
class BaseGANModel(BaseModel):
"""
Base class of GAN synthesizer models.
The main methods are train (for fitting the synthesizer), save/load and sample (obtain synthetic records).
Args:
model_parameters (ModelParameters):
Set of architectural parameters for model definition.
"""
__MODEL__ = None

def __init__(
self,
model_parameters: ModelParameters
Expand Down Expand Up @@ -84,7 +124,7 @@ def __init__(
self.gp_lambda = model_parameters.gp_lambda
self.pac = model_parameters.pac

self.processor = None
self.processor=None
if self.__MODEL__ in RegularModels.__members__ or \
self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
self.tau = model_parameters.tau_gs
Expand Down Expand Up @@ -183,8 +223,8 @@ def save(self, path):
make_keras_picklable()
dump(self, path)

@staticmethod
def load(path):
@classmethod
def load(cls, path):
"""
### Description:
Loads a saved synthesizer from a pickle.
Expand Down
2 changes: 1 addition & 1 deletion src/ydata_synthetic/synthesizers/regular/cgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#Import ydata synthetic classes
from ....synthesizers import TrainParameters
from ....synthesizers.gan import ConditionalModel
from ....synthesizers.base import ConditionalModel

class CGAN(ConditionalModel):
"CGAN model for discrete conditions"
Expand Down
4 changes: 2 additions & 2 deletions src/ydata_synthetic/synthesizers/regular/cramergan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

#Import ydata synthetic classes
from ....synthesizers import TrainParameters
from ....synthesizers.gan import BaseModel
from ....synthesizers.base import BaseGANModel
from ....synthesizers.loss import Mode, gradient_penalty

class CRAMERGAN(BaseModel):
class CRAMERGAN(BaseGANModel):

__MODEL__='CRAMERGAN'

Expand Down

0 comments on commit 81abe1d

Please sign in to comment.