CTGAN Model
===========

In this guide we will go through a series of steps that will let you
discover functionalities of the `CTGAN` model, including how to:

-   Create an instance of `CTGAN`.
-   Fit the instance to your data.
-   Generate synthetic versions of your data.
-   Use `CTGAN` to anonymize PII information.
-   Customize the data transformations to improve the learning process.
-   Specify hyperparameters to improve the output quality.

What is CTGAN?
--------------

The `sdv.tabular.CTGAN` model is based on the GAN-based Deep Learning
data synthesizer which was presented at the NeurIPS 2019 conference in
the paper titled [Modeling Tabular data using Conditional
GAN](https://arxiv.org/abs/1907.00503).

Let's now discover how to learn a dataset and later on generate
synthetic data with the same format and statistical properties by using
the `CTGAN` class from SDV.

Quick Usage
-----------

We will start by loading one of our demo datasets, the
`student_placements`, which contains information about MBA students that
applied for placements during the year 2020.

<div class="alert alert-warning">

**Warning**

In order to follow this guide you need to have `ctgan` installed on your
system. If you have not done it yet, please install `ctgan` now by
executing the command `pip install sdv[ctgan]` in a terminal.

</div>

In [1]:
from sdv.demo import load_tabular_demo

data = load_tabular_demo('student_placements')
data.head()

Unnamed: 0,student_id,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
0,17264,M,67.0,91.0,Commerce,58.0,Sci&Tech,False,0,55.0,Mkt&HR,58.8,27000.0,True,2020-07-23,2020-10-12,3.0
1,17265,M,79.33,78.33,Science,77.48,Sci&Tech,True,1,86.5,Mkt&Fin,66.28,20000.0,True,2020-01-11,2020-04-09,3.0
2,17266,M,65.0,68.0,Arts,64.0,Comm&Mgmt,False,0,75.0,Mkt&Fin,57.8,25000.0,True,2020-01-26,2020-07-13,6.0
3,17267,M,56.0,52.0,Science,52.0,Sci&Tech,False,0,66.0,Mkt&HR,59.43,,False,NaT,NaT,
4,17268,M,85.8,73.6,Commerce,73.3,Comm&Mgmt,False,0,96.8,Mkt&Fin,55.5,42500.0,True,2020-07-04,2020-09-27,3.0


As you can see, this table contains information about students which
includes, among other things:

-   Their id and gender
-   Their grades and specializations
-   Their work experience
-   The salary that they were offered
-   The duration and dates of their placement

You will notice that there is data with the following characteristics:

-   There are float, integer, boolean, categorical and datetime values.
-   There are some variables that have missing data. In particular, all
    the data related to the placement details is missing in the rows
    where the student was not placed.
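
The second point can be checked directly. The following is a minimal sketch in plain Python, using three of the rows shown above copied by hand, with `None` standing in for missing values. The helper name `placement_fields_missing` is hypothetical and is not part of SDV.

```python
# Three rows copied by hand from the table above; None marks a missing value.
rows = [
    {"student_id": 17264, "placed": True, "salary": 27000.0,
     "start_date": "2020-07-23", "end_date": "2020-10-12", "duration": 3.0},
    {"student_id": 17266, "placed": True, "salary": 25000.0,
     "start_date": "2020-01-26", "end_date": "2020-07-13", "duration": 6.0},
    {"student_id": 17267, "placed": False, "salary": None,
     "start_date": None, "end_date": None, "duration": None},
]

PLACEMENT_FIELDS = ("salary", "start_date", "end_date", "duration")


def placement_fields_missing(row):
    """Return True if every placement-related field in the row is missing."""
    return all(row[field] is None for field in PLACEMENT_FIELDS)


# Students that were not placed should have no placement details at all.
not_placed = [row for row in rows if not row["placed"]]
print(all(placement_fields_missing(row) for row in not_placed))
```

With the full table loaded as a DataFrame, the same idea would apply to every row rather than to a hand-picked subset.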

Let us use `CTGAN` to learn this data and then sample synthetic data
about new students to see how well the model captures the characteristics
indicated above. In order to do this you will need to:

-   Import the `sdv.tabular.CTGAN` class and create an instance of it.
-   Call its `fit` method passing our table.
-   Call its `sample` method indicating the number of synthetic rows
    that you want to generate.

In [2]:
from sdv.tabular import CTGAN

model = CTGAN()
model.fit(data)

<div class="alert alert-info">

**Note**

Notice that the model `fitting` process took care of transforming the
different fields using the appropriate [Reversible Data
Transforms](http://github.com/sdv-dev/RDT) to ensure that the data has a
format that the underlying CTGANSynthesizer class can handle.

</div>

### Generate synthetic data from the model

Once the modeling has finished you are ready to generate new synthetic
data by calling the `sample` method of your model, passing the number
of rows that you want to generate.

In [3]:
new_data = model.sample(200)

This will return a table with the same format as the one which the model
was fitted on, but filled with new data which resembles the original one.

In [4]:
new_data.head()

Unnamed: 0,student_id,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
0,17395,F,55.151151,69.21726,Science,75.055894,Comm&Mgmt,True,0,83.628662,Mkt&Fin,52.770144,27870.59516,True,2020-03-18,NaT,12.0
1,17352,F,50.531162,103.005005,Arts,57.469543,Sci&Tech,False,0,80.459705,Mkt&HR,69.130408,,True,2020-09-16,2020-11-13,3.0
2,17308,F,70.149168,70.505952,Science,60.46002,Sci&Tech,True,0,94.729369,Mkt&HR,66.688696,29128.897528,False,2020-03-03,NaT,3.0
3,17281,F,60.441178,68.69775,Science,66.136382,Comm&Mgmt,False,0,72.358689,Mkt&Fin,65.408333,28382.651974,False,2020-01-20,NaT,
4,17222,M,43.908954,60.566542,Science,83.90779,Comm&Mgmt,False,0,100.004639,Mkt&Fin,67.801198,41273.372854,True,2020-02-18,NaT,


<div class="alert alert-info">

**Note**

You can control the number of rows by passing the desired number to
`model.sample(<num_rows>)`. To test, try `model.sample(10000)`.
Note that the original table only had ~200 rows.

</div>

### Save and Load the model

In many scenarios it will be convenient to generate synthetic versions
of your data directly in systems that do not have access to the original
data source. For example, you may want to generate testing data on
the fly inside a testing environment that does not have access to your
production database. In these scenarios, fitting the model with real
data every time that you need to generate new data is not feasible, so you
will need to fit a model in your production environment, save the fitted
model into a file, send this file to the testing environment and then
load it there to be able to `sample` from it.

Let's see how this process works.

#### Save and share the model

Once you have fitted the model, all you need to do is call its `save`
method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the `.pkl` extension to highlight that the serialization protocol
used is [pickle](https://docs.python.org/3/library/pickle.html).

In [5]:
model.save('my_model.pkl')

This will have created a file called `my_model.pkl` in the same
directory in which you are running SDV.

<div class="alert alert-info">

**Important**

If you inspect the generated file you will notice that its size is much
smaller than the size of the data that you used to generate it. This is
because the serialized model contains **no information about the
original data**, other than the parameters it needs to generate
synthetic versions of it. This means that you can safely share this
`my_model.pkl` file without the risk of disclosing any of your real
data!

</div>

#### Load the model and generate new data

The file you just generated can be sent over to the system where the
synthetic data will be generated. Once it is there, you can load it
using the `CTGAN.load` method, and then you are ready to sample new data
from the loaded instance:

In [6]:
loaded = CTGAN.load('my_model.pkl')
new_data = loaded.sample(200)

<div class="alert alert-warning">

**Warning**

Notice that the system where the model is loaded needs to also have
`sdv` and `ctgan` installed, otherwise it will not be able to load the
model and use it.

</div>
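
The save-then-load round trip described above relies on pickle. As a rough illustration of that mechanism only, here is a minimal sketch using the standard library. The helpers `save_object` and `load_object` are hypothetical names, and a plain dictionary stands in for a fitted model; this is not how SDV implements `save` and `load` internally.

```python
import os
import pickle
import tempfile


def save_object(obj, path):
    """Serialize `obj` to `path` using pickle."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_object(path):
    """Load an object previously written by `save_object`."""
    with open(path, "rb") as f:
        return pickle.load(f)


# A plain dict stands in for a fitted model in this sketch.
toy_model = {"primary_key": "student_id", "epochs": 300}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_model.pkl")
    save_object(toy_model, path)
    restored = load_object(path)

# The restored object is an equal but separate copy of the original.
print(restored == toy_model)
```

Since unpickling can execute arbitrary code, pickle files should only be loaded when they come from a source you trust.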

### Specifying the Primary Key of the table

One of the first things that you may have noticed when looking at the demo
data is that there is a `student_id` column which acts as the primary
key of the table, and which is supposed to have unique values. Indeed,
if we look at the number of times that each value appears, we see that
all of them appear at most once:

In [7]:
data.student_id.value_counts().max()

1

However, if we look at the synthetic data that we generated, we observe
that there are some values that appear more than once:

In [8]:
new_data[new_data.student_id == new_data.student_id.value_counts().index[0]]

Unnamed: 0,student_id,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
15,17341,F,52.103283,67.305936,Science,65.063683,Comm&Mgmt,False,1,104.80813,Mkt&Fin,69.473722,27488.58143,False,NaT,NaT,3.0
21,17341,M,49.72142,43.397368,Science,70.047069,Comm&Mgmt,False,0,86.087502,Mkt&HR,59.883408,,True,2020-07-20,2021-02-25,12.0
25,17341,M,44.931732,53.931788,Commerce,56.92488,Comm&Mgmt,False,0,93.203709,Mkt&Fin,55.704451,28673.551492,True,2020-01-30,2020-09-20,6.0
74,17341,M,57.661074,65.337233,Science,68.107331,Sci&Tech,False,0,77.698156,Mkt&Fin,61.471423,56153.296134,False,2020-03-18,2020-11-19,6.0


This happens because the model was never told that `student_id` has to
be unique, so sooner or later it generates repeated values. To solve
this, we can pass the `primary_key` argument to our model when we create
it, indicating the name of the column that is the index of the table.

In [9]:
model = CTGAN(
    primary_key='student_id'
)
model.fit(data)
new_data = model.sample(200)
new_data.head()

Unnamed: 0,student_id,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
0,0,F,73.813341,62.244292,Commerce,70.411162,Comm&Mgmt,True,0,71.096949,Mkt&HR,46.254781,28940.194301,True,NaT,NaT,
1,1,F,51.601912,56.722433,Arts,78.659046,Sci&Tech,False,0,76.711212,Mkt&HR,64.377626,24499.026036,True,2020-03-07,2020-07-05,
2,2,F,68.035751,69.471137,Science,60.318167,Sci&Tech,False,0,102.576115,Mkt&HR,59.438222,27088.484079,False,2020-03-10,2020-08-13,3.0
3,3,F,64.046875,67.482545,Commerce,50.847634,Others,False,0,81.527101,Mkt&HR,55.814594,,False,NaT,2020-07-30,
4,4,M,73.125859,69.976795,Science,63.217668,Comm&Mgmt,False,0,93.404223,Mkt&HR,52.598417,28179.468679,True,2020-06-01,2020-06-02,3.0


As a result, the model will learn that this column must be unique and
generate a unique sequence of values for the column:

In [10]:
new_data.student_id.value_counts().max()

1
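
To get a feel for why an identifier column that is sampled like any other value collides so readily, the sketch below computes the probability of at least one repeated value. It assumes, purely for illustration, that identifiers are drawn uniformly and independently from roughly 200 distinct values; this is a simplification and not a description of how CTGAN actually models the column. The function name is hypothetical.

```python
def prob_at_least_one_duplicate(num_distinct, num_draws):
    """Probability that `num_draws` uniform draws with replacement from
    `num_distinct` values contain at least one repeated value."""
    if num_draws > num_distinct:
        # More draws than distinct values: a repeat is guaranteed.
        return 1.0
    prob_all_unique = 1.0
    for i in range(num_draws):
        # The (i + 1)-th draw must avoid the i values already drawn.
        prob_all_unique *= (num_distinct - i) / num_distinct
    return 1.0 - prob_all_unique


# Assumed figures for illustration: about 200 distinct ids, 200 sampled rows.
print(prob_at_least_one_duplicate(200, 200) > 0.999)
```

Under these assumptions a repeated identifier is practically certain, which is why the uniqueness requirement has to be stated explicitly through `primary_key`.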

### Anonymizing Personally Identifiable Information (PII)

There will be many cases where the data will contain Personally
Identifiable Information which we cannot disclose. In these cases, we
will want our Tabular Models to replace the information within these
fields with fake, simulated data that looks similar to the real one but
does not contain any of the original values.

Let's load a new dataset that contains a PII field, the
`student_placements_pii` demo, and try to generate synthetic versions of
it that do not contain any of the PII fields.

<div class="alert alert-info">

**Note**

The `student_placements_pii` dataset is a modified version of the
`student_placements` dataset with one new field, `address`, which
contains PII information about the students. Notice that this additional
`address` field has been simulated and does not correspond to data from
the real users.

</div>

In [11]:
data_pii = load_tabular_demo('student_placements_pii')
data_pii.head()

Unnamed: 0,student_id,address,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
0,17264,"70304 Baker Turnpike\nEricborough, MS 15086",M,67.0,91.0,Commerce,58.0,Sci&Tech,False,0,55.0,Mkt&HR,58.8,27000.0,True,2020-07-23,2020-10-12,3.0
1,17265,"805 Herrera Avenue Apt. 134\nMaryview, NJ 36510",M,79.33,78.33,Science,77.48,Sci&Tech,True,1,86.5,Mkt&Fin,66.28,20000.0,True,2020-01-11,2020-04-09,3.0
2,17266,"3702 Bradley Island\nNorth Victor, FL 12268",M,65.0,68.0,Arts,64.0,Comm&Mgmt,False,0,75.0,Mkt&Fin,57.8,25000.0,True,2020-01-26,2020-07-13,6.0
3,17267,Unit 0879 Box 3878\nDPO AP 42663,M,56.0,52.0,Science,52.0,Sci&Tech,False,0,66.0,Mkt&HR,59.43,,False,NaT,NaT,
4,17268,"96493 Kelly Canyon Apt. 145\nEast Steven, NC 3...",M,85.8,73.6,Commerce,73.3,Comm&Mgmt,False,0,96.8,Mkt&Fin,55.5,42500.0,True,2020-07-04,2020-09-27,3.0


If we use our tabular model on this new data we will see how the
synthetic data that it generates discloses the addresses from the real
students:

In [12]:
model = CTGAN(
    primary_key='student_id',
)
model.fit(data_pii)
new_data_pii = model.sample(200)
new_data_pii.head()

Unnamed: 0,student_id,address,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
0,0,"75756 Anthony Prairie\nNorth Charlene, CA 60546",M,63.077572,38.787614,Science,78.271952,Sci&Tech,False,0,79.355913,Mkt&Fin,60.56443,,False,NaT,NaT,
1,1,"185 Jasmine Ferry Apt. 294\nPort Nancy, NY 68773",M,67.684419,63.19147,Arts,82.399588,Sci&Tech,True,0,64.040811,Mkt&Fin,50.508109,25111.787617,True,NaT,2020-07-27,
2,2,"8236 William Prairie Apt. 451\nRobertatown, ND...",M,95.420389,55.596998,Commerce,66.731186,Comm&Mgmt,False,1,51.898863,Mkt&Fin,56.078154,25930.497044,True,NaT,NaT,6.0
3,3,044 Natasha Squares Apt. 005\nEast Michaelboro...,M,87.387236,77.707899,Science,87.17799,Comm&Mgmt,False,0,94.599228,Mkt&HR,54.74486,,True,NaT,NaT,12.0
4,4,"61808 Catherine Center\nEast Lonnieport, OK 54464",F,83.925508,38.800122,Science,56.352591,Comm&Mgmt,True,0,67.684155,Mkt&Fin,63.619684,25534.168837,True,2020-06-14,2020-07-23,


More specifically, we can see how all the addresses that have been
generated actually come from the original dataset:

In [13]:
new_data_pii.address.isin(data_pii.address).sum()

200

In order to solve this, we can pass an additional argument
`anonymize_fields` to our model when we create the instance. This
`anonymize_fields` argument will need to be a dictionary that contains:

-   The name of the field that we want to anonymize.
-   The category of the field that we want to use when we generate fake
    values for it.

The complete list of possible categories can be seen in the [Faker
Providers](https://faker.readthedocs.io/en/master/providers.html) page,
and it contains a huge list of concepts such as:

-   name
-   address
-   country
-   city
-   ssn
-   credit_card_number
-   credit_card_expire
-   credit_card_security_code
-   email
-   telephone
-   ...

In this case, since the field is an address, we will pass a
dictionary indicating the category `address`:

In [14]:
model = CTGAN(
    primary_key='student_id',
    anonymize_fields={
        'address': 'address'
    }
)
model.fit(data_pii)

As a result, we can see how the real `address` values have been replaced
by other fake addresses:

In [15]:
new_data_pii = model.sample(200)
new_data_pii.head()

Unnamed: 0,student_id,address,gender,second_perc,high_perc,high_spec,degree_perc,degree_type,work_experience,experience_years,employability_perc,mba_spec,mba_perc,salary,placed,start_date,end_date,duration
0,0,"272 Turner Forest Suite 987\nNorth Eric, NV 03200",M,74.725396,57.642797,Commerce,57.532698,Sci&Tech,False,0,85.663334,Mkt&HR,67.051413,43246.61162,True,2020-01-19,2020-08-22,6.0
1,1,"6861 Ashley Neck\nPort Nicoleville, WV 72346",F,44.403354,62.440652,Science,52.518439,Comm&Mgmt,True,0,64.155841,Mkt&Fin,78.187649,31471.609046,True,NaT,2020-09-14,3.0
2,2,"401 Ryan Harbors\nMoorestad, TN 01585",F,77.851161,67.601876,Science,48.965949,Comm&Mgmt,True,0,68.60245,Mkt&HR,64.768897,41941.232075,True,2020-06-19,NaT,
3,3,"978 Keith Pines Suite 446\nRileyland, DC 35771",M,60.395844,67.275818,Science,44.648989,Comm&Mgmt,False,0,79.761588,Mkt&Fin,59.002028,27319.962285,True,NaT,2021-02-20,
4,4,USNS Paul\nFPO AP 92961,F,63.887682,68.896234,Science,70.598204,Comm&Mgmt,False,1,54.076206,Mkt&Fin,57.258452,48942.497936,True,2020-03-29,2020-11-01,3.0


Which means that none of the original addresses can be found in the
sampled data:

In [16]:
data_pii.address.isin(new_data_pii.address).sum()

0
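
The same leakage check can be expressed without pandas. The following minimal sketch uses a few of the addresses shown in the tables above, copied by hand; `count_leaked` is a hypothetical helper name, not an SDV function.

```python
def count_leaked(real_values, synthetic_values):
    """Count synthetic values that also appear among the real values."""
    real = set(real_values)
    return sum(1 for value in synthetic_values if value in real)


# Two real addresses from the original demo table.
real_addresses = [
    "70304 Baker Turnpike\nEricborough, MS 15086",
    "Unit 0879 Box 3878\nDPO AP 42663",
]

# Two addresses sampled after enabling `anonymize_fields`.
synthetic_addresses = [
    "401 Ryan Harbors\nMoorestad, TN 01585",
    "USNS Paul\nFPO AP 92961",
]

print(count_leaked(real_addresses, synthetic_addresses))  # 0
```

A result of `0` means that none of the sampled values coincides with a real one, which is the property we want for an anonymized field.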

Advanced Usage
--------------

Now that we have discovered the basics, let's go over a few more
advanced usage examples and see the different arguments that we can pass
to our `CTGAN` Model in order to customize it to our needs.

### How to modify the CTGAN Hyperparameters?

Apart from the common Tabular Model arguments, `CTGAN` has a number of
additional hyperparameters that control its learning behavior and can
impact the performance of the model, both in terms of quality of the
generated data and computational time.

-   `epochs` and `batch_size`: these arguments control the number of
    iterations that the model will perform to optimize its parameters,
    as well as the number of samples used in each step. Their default
    values are `300` and `500` respectively, and `batch_size` must
    always be a multiple of `10`. These hyperparameters have a very
    direct effect on how long the training process lasts, but also on
    the quality of the generated data, so for new datasets you might
    want to start with low values for both of them to see how long
    the training process takes on your data, and later increase them
    to acceptable values in order to improve the results.
-   `log_frequency`: Whether to use log frequency of categorical levels
    in conditional sampling. It defaults to `True`. This argument affects
    how the model processes the frequencies of the categorical values that
    are used to condition the rest of the values. In some cases, changing
    it to `False` could lead to better performance.
-   `embedding_dim` (int): Size of the random sample passed to the
    Generator. Defaults to 128.
-   `generator_dim` (tuple or list of ints): Size of the output samples for
    each one of the Residuals. A Residual Layer will be created for each
    one of the values provided. Defaults to (256, 256).
-   `discriminator_dim` (tuple or list of ints): Size of the output samples for
    each one of the Discriminator Layers. A Linear Layer will be created
    for each one of the values provided. Defaults to (256, 256).
-   `generator_lr` (float): Learning rate for the generator. Defaults to 2e-4.
-   `generator_decay` (float): Generator weight decay for the Adam Optimizer.
    Defaults to 1e-6.
-   `discriminator_lr` (float): Learning rate for the discriminator.
    Defaults to 2e-4.
-   `discriminator_decay` (float): Discriminator weight decay for the Adam
    Optimizer. Defaults to 1e-6.
-   `discriminator_steps` (int): Number of discriminator updates to do for
    each generator update. From the WGAN paper: https://arxiv.org/abs/1701.07875.
    WGAN paper default is 5. Default used is 1 to match original CTGAN
    implementation.
-   `verbose`: Whether to print fit progress on stdout. Defaults to
    `False`.

<div class="alert alert-warning">

**Warning**

Notice that the value that you set on the `batch_size` argument must
always be a multiple of `10`!

</div>
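
One way to keep track of these settings is to collect them in a dictionary and validate it before creating the model. The sketch below is only an illustration: the default values are the ones listed in this guide, while the `CTGAN_DEFAULTS` name and the `build_hyperparameters` helper are hypothetical and not part of SDV.

```python
# Default values as listed in this guide.
CTGAN_DEFAULTS = {
    "epochs": 300,
    "batch_size": 500,
    "log_frequency": True,
    "embedding_dim": 128,
    "generator_dim": (256, 256),
    "discriminator_dim": (256, 256),
    "generator_lr": 2e-4,
    "generator_decay": 1e-6,
    "discriminator_lr": 2e-4,
    "discriminator_decay": 1e-6,
    "discriminator_steps": 1,
    "verbose": False,
}


def build_hyperparameters(**overrides):
    """Merge overrides into the defaults and run basic sanity checks."""
    unknown = set(overrides) - set(CTGAN_DEFAULTS)
    if unknown:
        raise ValueError(f"Unknown hyperparameters: {sorted(unknown)}")
    params = {**CTGAN_DEFAULTS, **overrides}
    if params["batch_size"] % 10 != 0:
        raise ValueError("batch_size must be a multiple of 10")
    return params


params = build_hyperparameters(epochs=500, batch_size=100)
print(params["epochs"], params["batch_size"])
```

Assuming the argument names above match the installed version, the resulting dictionary could then be unpacked into the constructor, for example `CTGAN(primary_key='student_id', **params)`.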

As an example, we will try to fit the `CTGAN` model slightly increasing
the number of epochs, reducing the `batch_size`, and adding one
additional layer to the models involved.

Before we start, we will evaluate the quality of the previously
generated data using the `sdv.evaluation.evaluate` function:

In [17]:
from sdv.evaluation import evaluate

evaluate(new_data, data)

0.7227906796224464

Afterwards, we create a new instance of the `CTGAN` model with the
hyperparameter values that we want to use:

In [18]:
model = CTGAN(
    primary_key='student_id',
    epochs=500,
    batch_size=100,
    generator_dim=(256, 256, 256),
    discriminator_dim=(256, 256, 256)
)

And fit it to our data.

In [19]:
model.fit(data)

Finally, we are ready to generate new data and evaluate the results.

In [20]:
new_data = model.sample(len(data))
evaluate(new_data, data)

0.726067972824495

As we can see, in this case these modifications changed the obtained
results slightly, but they did not introduce dramatic changes in the
performance either.

### How do I specify constraints?

If you look closely at the data you may notice that some properties were
not completely captured by the model. For example, you may have seen
that sometimes the model produces an `experience_years` number greater
than `0` while also indicating that `work_experience` is `False`. These
types of properties are what we call `Constraints` and can also be
handled using `SDV`. For further details about them please visit the
[Handling Constraints](04_Handling_Constraints.ipynb) guide.
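
To spot such inconsistencies in a sample, a simple row-level check is enough. The sketch below is a plain Python illustration with a hypothetical helper name; the value pairs are copied from synthetic rows shown earlier in this guide. It only detects violations and does not show how SDV enforces constraints.

```python
def violates_experience_constraint(row):
    """True if the row reports years of experience but no work experience."""
    return (not row["work_experience"]) and row["experience_years"] > 0


# (work_experience, experience_years) pairs taken from sampled rows above.
rows = [
    {"work_experience": False, "experience_years": 1},
    {"work_experience": False, "experience_years": 0},
    {"work_experience": True, "experience_years": 0},
]

violations = [row for row in rows if violates_experience_constraint(row)]
print(len(violations))  # 1
```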

### Can I evaluate the Synthetic Data?

A very common question when someone starts using **SDV** to generate
synthetic data is: *"How good is the data that I just generated?"*

In order to answer this question, **SDV** has a collection of metrics
and tools that allow you to compare the *real* data that you provided and the
*synthetic* data that you generated using **SDV** or any other tool.

You can read more about this in the [Evaluating Synthetic Data Generators](
05_Evaluating_Synthetic_Data_Generators.ipynb) guide.