In [1]:
!pip install sdv



### Load dataset

In [15]:
from sdv.demo import load_tabular_demo
import pandas as pd

# Real loan records the synthesizer will learn from. The path is relative to
# the notebook's working directory.
data = pd.read_csv(filepath_or_buffer='data/infringement_dataset_v2.csv')
data.head(n=5)


Unnamed: 0,loan_id,infringed,contract_type,gender,has_own_car,has_own_realty,num_children,annual_income,credit_amount,credit_annuity,...,first_name,last_name,past_avg_amount_annuity,past_avg_amt_application,past_avg_amt_credit,past_loans_approved,past_loans_refused,past_loans_canceled,past_loans_unused,past_loans_total
0,100002,1,Cash loans,M,N,Y,0,202500.0,406597.5,24700.5,...,Robert,Watkins,9251.775,179055.0,179055.0,1.0,0.0,0.0,0.0,1.0
1,100003,0,Cash loans,F,N,N,0,270000.0,1293502.5,35698.5,...,Jane,Navarro,56553.99,435436.5,484191.0,3.0,0.0,0.0,0.0,3.0
2,100004,0,Revolving loans,M,Y,Y,0,67500.0,135000.0,6750.0,...,David,Seagraves,5357.25,24282.0,20106.0,1.0,0.0,0.0,0.0,1.0
3,100006,0,Cash loans,F,N,Y,0,135000.0,312682.5,29686.5,...,Deborah,Tandy,23651.175,272203.26,291695.5,5.0,3.0,1.0,0.0,9.0
4,100007,0,Cash loans,M,N,Y,0,121500.0,513000.0,21865.5,...,David,Walker,12278.805,150530.25,166638.75,6.0,0.0,0.0,0.0,6.0


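### Fit a GaussianCopula model to the data (train the model)
The cell below uses SDV's `GaussianCopula` synthesizer. CTGAN (paper: https://arxiv.org/pdf/1907.00503.pdf) is a GAN-based alternative available in the same `sdv.tabular` module; it is typically much slower to train.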

In [17]:
import warnings

from sdv.tabular import GaussianCopula

# Fit SDV's Gaussian-copula synthesizer to the real table.
# NOTE(review): `loan_id` looks like an identifier, and `first_name`/`last_name`
# look like personal names. With the default configuration below they are
# modelled like any other column, so synthetic rows may reuse real names.
# Consider:
#     GaussianCopula(primary_key='loan_id',
#                    anonymize_fields={'first_name': 'first_name',
#                                      'last_name': 'last_name'})
# TODO: confirm against the installed SDV 0.x version before enabling.
model = GaussianCopula()

# Suppress warnings only while fitting. A global
# `warnings.filterwarnings('ignore')` would also silence warnings in every
# later cell of the session (evaluation, reports, plots).
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    model.fit(data)

### Create synthetic samples using the trained model

In [19]:
# Draw 200 synthetic rows from the fitted model and preview the first five.
new_data = model.sample(num_rows=200)
new_data.head(n=5)

Unnamed: 0,loan_id,infringed,contract_type,gender,has_own_car,has_own_realty,num_children,annual_income,credit_amount,credit_annuity,...,first_name,last_name,past_avg_amount_annuity,past_avg_amt_application,past_avg_amt_credit,past_loans_approved,past_loans_refused,past_loans_canceled,past_loans_unused,past_loans_total
0,345607,0,Cash loans,F,N,N,1,134540.0,765580.0,28584.5,...,Joan,Ford,27519.0,304298.0,299738.0,1.0,2.0,2.0,0.0,4.0
1,134652,0,Cash loans,F,N,Y,1,362026.0,652081.0,13033.8,...,Latoya,Osborn,4066.0,32385.0,68989.0,3.0,1.0,1.0,1.0,4.0
2,232974,1,Cash loans,F,N,Y,0,243592.0,736022.0,19096.7,...,Peter,Threlkeld,5017.0,76384.0,60709.0,1.0,1.0,0.0,0.0,2.0
3,252313,0,Cash loans,F,N,N,0,459844.0,727017.0,32507.7,...,Maurice,Hamilton,,,,,,,,
4,445411,0,Cash loans,M,Y,Y,1,333030.0,512096.0,24031.2,...,Gary,Ray,13627.0,147402.0,145684.0,3.0,0.0,1.0,1.0,4.0


### Evaluate synthetic data

In [20]:
from sdv.evaluation import evaluate

# Single aggregate similarity score between the synthetic and the real table.
# Argument order matters: synthetic data first, real data second.
overall_score = evaluate(synthetic_data=new_data, real_data=data)
overall_score

0.9068501041906748

### Quality report

In [21]:
from sdmetrics.reports.single_table import QualityReport

# Score how closely the synthetic table matches the real one, using the column
# types recorded in the fitted model's metadata. This is slow on a wide table:
# the recorded run took about 17.5 minutes.
my_report = QualityReport()
my_report.generate(
    real_data=data,
    synthetic_data=new_data,
    metadata=model.get_metadata().to_dict(),
)


Creating report: 100%|██████████| 4/4 [17:37<00:00, 264.41s/it]



Overall Quality Score: 81.72%

Properties:
Column Shapes: 80.84%
Column Pair Trends: 82.6%


### Column comparison

In [22]:
# Per-column breakdown of the "Column Shapes" property (one metric score per column).
my_report.get_details('Column Shapes')

Unnamed: 0,Column,Metric,Quality Score
0,loan_id,KSComplement,0.962514
1,infringed,KSComplement,0.930729
2,num_children,KSComplement,0.729632
3,annual_income,KSComplement,0.698224
4,credit_amount,KSComplement,0.873148
5,credit_annuity,KSComplement,0.926044
6,goods_valuation,KSComplement,0.883255
7,age,KSComplement,0.955033
8,days_employed,KSComplement,0.250072
9,car_age,KSComplement,0.859108


### Column pair trends (pairwise relationships between columns)

In [23]:
# Visual summary of the "Column Pair Trends" property from the quality report.
my_report.get_visualization('Column Pair Trends')

### Distribution comparison

In [25]:
from sdmetrics.reports.utils import get_column_plot

# Overlay the real and synthetic distributions of selected columns, one figure
# per column. Add names to COLUMNS_TO_PLOT instead of copy-pasting this cell.
COLUMNS_TO_PLOT = ['age', 'infringed']

# The metadata does not change between plots, so compute it once.
plot_metadata = model.get_metadata().to_dict()

for column_name in COLUMNS_TO_PLOT:
    fig = get_column_plot(
        real_data=data,
        synthetic_data=new_data,
        metadata=plot_metadata,
        column_name=column_name,
    )
    fig.show()