In [16]:
import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer, TVAESynthesizer
from sdv.evaluation.single_table import evaluate_quality
from sdv.sampling import Condition


In [2]:
train = pd.read_csv('../data/5_train_dataset.csv')
train.head()

Unnamed: 0,is_fraud,ERC20_num_unique_senders_to_acc,total_transactions,ERC20_num_unique_senders_contract_to_acc,total_value_sent_ratio,avg_value_sent,num_received_to_total_txns_ratio,min_value_sent,ERC20_num_unique_recipients_from_acc,num_unique_senders_to_acc,...,num_unique_recipients_from_acc,first_and_last_txns_time_diff,total_value_received_ratio,has_sent_ERC20,max_value_sent,avg_time_between_received_txns,min_value_received,avg_time_between_sent_txns,total_ERC20_sent_contract_in_ether,ERC20_max_value_sent
0,0,0.0,111,0.0,0.500672,0.156104,0.504505,0.015341,0.0,3,...,2,450875.98,0.499328,0,1.001666,8048.14,0.015761,3.28,0.0,0.0
1,0,0.0,260,0.0,0.500388,0.270446,0.5,0.02558,0.0,2,...,2,136967.03,0.499612,0,0.791452,957.18,0.026,96.42,0.0,0.0
2,0,0.0,6,0.0,0.500005,25.24947,0.333333,0.004454,0.0,2,...,2,331.23,0.499995,0,100.41452,0.02,26.376409,82.8,0.0,0.0
3,1,0.0,7,0.0,0.0,0.0,0.857143,0.0,0.0,3,...,0,176158.57,0.0,0,0.0,29359.76,0.0,0.0,0.0,0.0
4,1,0.0,6,0.0,0.0,0.0,0.833333,0.0,0.0,3,...,0,173765.32,0.0,0,0.0,34753.06,0.0,0.0,0.0,0.0


In [3]:
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data=train)
metadata

{
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
    "columns": {
        "is_fraud": {
            "sdtype": "categorical"
        },
        "ERC20_num_unique_senders_to_acc": {
            "sdtype": "numerical"
        },
        "total_transactions": {
            "sdtype": "numerical"
        },
        "ERC20_num_unique_senders_contract_to_acc": {
            "sdtype": "numerical"
        },
        "total_value_sent_ratio": {
            "sdtype": "numerical"
        },
        "avg_value_sent": {
            "sdtype": "numerical"
        },
        "num_received_to_total_txns_ratio": {
            "sdtype": "numerical"
        },
        "min_value_sent": {
            "sdtype": "numerical"
        },
        "ERC20_num_unique_recipients_from_acc": {
            "sdtype": "numerical"
        },
        "num_unique_senders_to_acc": {
            "sdtype": "numerical"
        },
        "max_value_received": {
            "sdtype": "numerical"
        },
        "total_ether_rec

In [5]:
synthesizer_ctgan = CTGANSynthesizer(metadata,verbose=True, epochs=500)
synthesizer_ctgan.fit(train)

Gen. (-0.53) | Discrim. (0.35): 100%|███████| 300/300 [4:02:59<00:00, 48.60s/it]


In [None]:
synthesizer_tvae = TVAESynthesizer(metadata, verbose=True, epochs=500)
synthesizer_tvae.fit(train)

In [30]:
diff = train['is_fraud'].value_counts()[0] - train['is_fraud'].value_counts()[1]
diff

3275

In [31]:
def get_fraud_cond(num_rows):

    return Condition(
        num_rows=num_rows,
        column_values={'is_fraud': 1}
    )

def get_non_fraud_cond(num_rows):

    return Condition(
        num_rows=num_rows,
        column_values={'is_fraud': 0}
    )

In [34]:
# Balance the dataset

synthetic_data_fraud_ctgan = synthesizer_ctgan.sample_from_conditions(
    conditions=[get_fraud_cond(25000+diff)])

synthetic_data_non_fraud_ctgan = synthesizer_ctgan.sample_from_conditions(
    conditions=[get_non_fraud_cond(25000)])

synthetic_data_ctgan = pd.concat([synthetic_data_fraud_ctgan, synthetic_data_non_fraud_ctgan])

Sampling conditions: 100%|███████████████| 28275/28275 [01:32<00:00, 304.09it/s]
Sampling conditions: 100%|███████████████| 25000/25000 [00:45<00:00, 546.05it/s]


In [35]:
# Balance the dataset

synthetic_data_fraud_tvae = synthesizer_tvae.sample_from_conditions(
    conditions=[get_fraud_cond(25000+diff)])

synthetic_data_non_fraud_tvae = synthesizer_tvae.sample_from_conditions(
    conditions=[get_non_fraud_cond(25000)])

synthetic_data_tvae = pd.concat([synthetic_data_fraud_tvae, synthetic_data_non_fraud_tvae])

1    28275
0    25000
Name: is_fraud, dtype: int64

In [None]:
synthetic_data_ctgan = synthetic_data_ctgan.reset_index(drop=True)
synthetic_data_tvae = synthetic_data_tvae.reset_index(drop=True)

In [36]:
# Quality Report - CTGAN

quality_report = evaluate_quality(
    real_data=train,
    synthetic_data=synthetic_data_fraud_ctgan,
    metadata=metadata)


Generating report ...
(1/2) Evaluating Column Shapes: : 100%|████████| 22/22 [00:00<00:00, 117.35it/s]
(2/2) Evaluating Column Pair Trends: : 100%|█| 231/231 [00:01<00:00, 168.49it/s]

Overall Quality Score: 74.7%

Properties:
- Column Shapes: 63.31%
- Column Pair Trends: 86.09%


In [None]:
# Quality Report - TVAE

quality_report = evaluate_quality(
    real_data=train,
    synthetic_data=synthetic_data_fraud_tvae,
    metadata=metadata)


In [None]:
quality_report = evaluate_quality(
    real_data=train,
    synthetic_data=synthetic_data_non_fraud_ctgan,
    metadata=metadata)

In [37]:
quality_report = evaluate_quality(
    real_data=train,
    synthetic_data=synthetic_data_non_fraud_tvae,
    metadata=metadata)

Generating report ...
(1/2) Evaluating Column Shapes: : 100%|████████| 22/22 [00:00<00:00, 115.77it/s]
(2/2) Evaluating Column Pair Trends: : 100%|█| 231/231 [00:01<00:00, 177.68it/s]

Overall Quality Score: 80.89%

Properties:
- Column Shapes: 70.18%
- Column Pair Trends: 91.61%


In [42]:
quality_report = evaluate_quality(
    real_data=train,
    synthetic_data=synthetic_data_ctgan,
    metadata=metadata)

In [43]:
quality_report = evaluate_quality(
    real_data=train,
    synthetic_data=synthetic_data_tvae,
    metadata=metadata)

Generating report ...

  0%|                                                    | 0/22 [00:00<?, ?it/s][A
(1/2) Evaluating Column Shapes: :   0%|                  | 0/22 [00:00<?, ?it/s][A
(1/2) Evaluating Column Shapes: :  32%|███▏      | 7/22 [00:00<00:00, 65.70it/s][A
(1/2) Evaluating Column Shapes: : 100%|█████████| 22/22 [00:00<00:00, 68.16it/s][A

  0%|                                                   | 0/231 [00:00<?, ?it/s][A
(2/2) Evaluating Column Pair Trends: :   0%|            | 0/231 [00:00<?, ?it/s][A
(2/2) Evaluating Column Pair Trends: :   1%|    | 2/231 [00:00<00:12, 18.34it/s][A
(2/2) Evaluating Column Pair Trends: :   3%|    | 6/231 [00:00<00:08, 27.80it/s][A
(2/2) Evaluating Column Pair Trends: :   4%|▏  | 10/231 [00:00<00:07, 30.60it/s][A
(2/2) Evaluating Column Pair Trends: :   6%|▏  | 14/231 [00:00<00:06, 31.88it/s][A
(2/2) Evaluating Column Pair Trends: :   8%|▏  | 18/231 [00:00<00:06, 32.39it/s][A
(2/2) Evaluating Column Pair Trends: :  10%|▎  | 24/

In [47]:
augmented_data = pd.concat([train, synthetic_data]).reset_index(drop=True)
augmented_data

Unnamed: 0,is_fraud,ERC20_num_unique_senders_to_acc,total_transactions,ERC20_num_unique_senders_contract_to_acc,total_value_sent_ratio,avg_value_sent,num_received_to_total_txns_ratio,min_value_sent,ERC20_num_unique_recipients_from_acc,num_unique_senders_to_acc,...,num_unique_recipients_from_acc,first_and_last_txns_time_diff,total_value_received_ratio,has_sent_ERC20,max_value_sent,avg_time_between_received_txns,min_value_received,avg_time_between_sent_txns,total_ERC20_sent_contract_in_ether,ERC20_max_value_sent
0,0,0.0,111,0.0,0.500672,0.156104,0.504505,0.015341,0.0,3,...,2,450875.980000,0.499328,0,1.001666,8048.140000,0.015761,3.28,0.000000,0.000000e+00
1,0,0.0,260,0.0,0.500388,0.270446,0.500000,0.025580,0.0,2,...,2,136967.030000,0.499612,0,0.791452,957.180000,0.026000,96.42,0.000000,0.000000e+00
2,0,0.0,6,0.0,0.500005,25.249470,0.333333,0.004454,0.0,2,...,2,331.230000,0.499995,0,100.414520,0.020000,26.376409,82.80,0.000000,0.000000e+00
3,1,0.0,7,0.0,0.000000,0.000000,0.857143,0.000000,0.0,3,...,0,176158.570000,0.000000,0,0.000000,29359.760000,0.000000,0.00,0.000000,0.000000e+00
4,1,0.0,6,0.0,0.000000,0.000000,0.833333,0.000000,0.0,3,...,0,173765.320000,0.000000,0,0.000000,34753.060000,0.000000,0.00,0.000000,0.000000e+00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59159,0,0.0,679,0.0,0.498970,32.665619,0.503103,1.037651,0.0,6,...,4,5108.817305,0.500093,0,0.000000,19963.557698,0.000000,186.70,17.771658,8.048208e+06
59160,0,0.0,21,0.0,0.494286,34.675133,0.498906,4.763361,1.0,18,...,3,515829.464344,0.497344,0,8.066305,462.529426,114.588304,466.18,23.534153,1.158387e+07
59161,0,0.0,57,0.0,0.498770,0.000000,0.503460,0.335458,1.0,7,...,6,364168.673151,0.497342,0,48.845611,8408.506005,3.623419,0.00,0.000000,0.000000e+00
59162,0,1.0,2781,0.0,0.230138,0.000000,0.375489,0.000000,36.0,5,...,7,864095.182810,0.496848,1,78.060760,0.000000,5.337699,5154.11,0.000000,1.909525e+06


In [48]:
augmented_data['is_fraud'].value_counts()

0    29582
1    29582
Name: is_fraud, dtype: int64

In [49]:
augmented_data.to_csv("../data/6_train_dataset_augmented.csv", index=False)

In [None]:
synthesizer_ctgan.save('../models/ctgan.pkl')
synthesizer_tvae.save('../models/tvae.pkl')