In [None]:
# Dependencies: uncomment and run once in a fresh environment.
# %pip install sdv ctgan torch

In [28]:
import pandas as pd
from tvae import TVAE

# Location of the Titanic passenger dataset (Stanford CS109 mirror).
url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'

# Load and preprocess in a single chain:
#   1. drop the 'Name' column, which is not needed for the analysis;
#   2. discard every row that contains a missing value;
#   3. encode 'Sex' as a number (male -> 0, female -> 1).
data = (
    pd.read_csv(url)
    .drop(columns=['Name'])
    .dropna()
    .assign(Sex=lambda frame: frame['Sex'].map({'male': 0, 'female': 1}))
)

# Preview the first rows of the prepared frame.
print(data.head())

   Survived  Pclass  Sex   Age  Siblings/Spouses Aboard  \
0         0       3    0  22.0                        1   
1         1       1    1  38.0                        1   
2         1       3    1  26.0                        0   
3         1       1    1  35.0                        1   
4         0       3    0  35.0                        0   

   Parents/Children Aboard     Fare  
0                        0   7.2500  
1                        0  71.2833  
2                        0   7.9250  
3                        0  53.1000  
4                        0   8.0500  


In [29]:
# Configure the TVAE model: 300 training epochs, mini-batches of 500 rows,
# and verbose output enabled.
tvae = TVAE(
    epochs=300,
    batch_size=500,
    verbose=True,
)


In [30]:
# Train TVAE on the prepared frame. The integer-coded columns listed below are
# passed as discrete_columns; 'Age' and 'Fare' are not listed.
tvae.fit(
    data,
    discrete_columns=[
        'Survived',
        'Pclass',
        'Sex',
        'Siblings/Spouses Aboard',
        'Parents/Children Aboard',
    ],
)



Loss: 6.078: 100%|██████████| 300/300 [00:07<00:00, 40.53it/s]


In [31]:
# Generate new synthetic data: request 10 samples from the trained TVAE model
# and print the result.
tvae_new_data = tvae.sample(10)
print(tvae_new_data)

   Survived  Pclass  Sex        Age  Siblings/Spouses Aboard  \
0         1       3    1  20.736773                        0   
1         0       3    0  22.513544                        0   
2         0       1    0  27.765766                        1   
3         0       3    0  18.670739                        0   
4         1       3    1  20.442953                        0   
5         0       3    0  30.124512                        0   
6         0       3    0  22.857849                        0   
7         1       2    1  36.740552                        0   
8         0       3    0  15.053903                        0   
9         0       3    1  25.575785                        0   

   Parents/Children Aboard       Fare  
0                        0  13.236631  
1                        0   7.856270  
2                        0  24.288985  
3                        0  57.603609  
4                        0  16.937087  
5                        0  13.163455  
6              

In [32]:
from sdv.single_table import CTGANSynthesizer
from sdv.metadata import SingleTableMetadata
# Initialise the CTGAN model: start from metadata auto-detected on the
# prepared DataFrame.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

# Auto-detection infers each column's sdtype from its values, so the
# integer-coded category columns may be inferred as numerical (this depends on
# the installed SDV version). Declare them categorical explicitly, using the
# same column list that the TVAE fit passes as discrete_columns, so both
# models are trained under the same assumptions.
for _discrete_column in [
    'Survived',
    'Pclass',
    'Sex',
    'Siblings/Spouses Aboard',
    'Parents/Children Aboard',
]:
    metadata.update_column(column_name=_discrete_column, sdtype='categorical')

model = CTGANSynthesizer(metadata)

# Train the model on the prepared frame.
model.fit(data)





In [33]:
# Generate new synthetic data from the fitted CTGAN synthesizer
# (here: 10 sample rows).
CTGAN_new_data = model.sample(num_rows=10)

# Print the generated rows.
print(CTGAN_new_data)

   Survived  Pclass  Sex    Age  Siblings/Spouses Aboard  \
0         0       3    0  11.90                        1   
1         0       3    0  19.63                        0   
2         1       1    1  54.80                        4   
3         0       3    1  17.49                        0   
4         1       3    1  55.91                        4   
5         1       3    0  37.95                        1   
6         0       2    1  65.24                        0   
7         0       1    0  24.67                        0   
8         1       3    0  28.69                        0   
9         0       3    1   0.42                        1   

   Parents/Children Aboard      Fare  
0                        2    7.1253  
1                        0   13.6876  
2                        0   27.4948  
3                        0   16.0078  
4                        1  114.1140  
5                        1   14.5457  
6                        0   17.1474  
7                        2 