- https://arxiv.org/abs/1811.11264
- https://github.com/DAI-Lab/TGAN
- https://dai-lab.github.io/TGAN/

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.display import display, HTML 
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
df = pd.read_csv('input/train.csv')
df.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [3]:
df.shape

(891, 12)

# TGANはNaNに対応していない

In [4]:
df.isnull().sum()

PassengerId      0
Survived         0
Pclass           0
Name             0
Sex              0
Age            177
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          687
Embarked         2
dtype: int64

In [5]:
df.drop('Cabin', axis=1, inplace=True)

In [6]:
avg_age = round(df['Age'].mean(), 0)
avg_age

30.0

In [7]:
df['Embarked'].value_counts()

S    644
C    168
Q     77
Name: Embarked, dtype: int64

In [8]:
df['Age'].fillna(avg_age, inplace=True)
df['Embarked'].fillna(df['Embarked'].value_counts().index[0], inplace=True)

In [9]:
df.isnull().sum()

PassengerId    0
Survived       0
Pclass         0
Name           0
Sex            0
Age            0
SibSp          0
Parch          0
Ticket         0
Fare           0
Embarked       0
dtype: int64

# TGANに入れないcolumnsは削除しておく

In [10]:
df.drop(['PassengerId','Name', 'Ticket'], axis=1, inplace=True)

In [11]:
df.head()

Unnamed: 0,Survived,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,0,3,male,22.0,1,0,7.25,S
1,1,1,female,38.0,1,0,71.2833,C
2,1,3,female,26.0,0,0,7.925,S
3,1,1,female,35.0,1,0,53.1,S
4,0,3,male,35.0,0,0,8.05,S


# TGANでcolumnsがindexに破壊的に置き換えられてしまうので、保存しておく

In [12]:
df_columns = df.columns

In [13]:
continuous_columns = [df.columns.get_loc(c) for c in df.select_dtypes(include=['float']).columns]
continuous_columns

[3, 6]

In [14]:
from tgan.model import TGANModel

W0819 02:17:06.711143 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/hooks.py:17: The name tf.train.SessionRunHook is deprecated. Please use tf.estimator.SessionRunHook instead.

W0819 02:17:06.724061 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/tfutils/optimizer.py:18: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0819 02:17:06.726121 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/tfutils/sesscreate.py:20: The name tf.train.SessionCreator is deprecated. Please use tf.compat.v1.train.SessionCreator instead.



In [15]:
%%time
# batch_sizeを小さめに指定しないと、tensorpack の assertion error で止まる
# https://github.com/tensorpack/tensorpack/blob/8112723601610a6a3a6211c9893bee23942c0848/tensorpack/dataflow/common.py#L90
tgan = TGANModel(continuous_columns, batch_size=50)
tgan.fit(df)

W0819 02:17:07.181424 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/graph_builder/model_desc.py:29: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0819 02:17:07.182371 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/graph_builder/model_desc.py:39: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0819 02:17:07.189288 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/input_source/input_source.py:219: The name tf.FIFOQueue is deprecated. Please use tf.queue.FIFOQueue instead.



CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 5.25 µs
[32m[0819 02:17:07 @input_source.py:222][0m Setting up the queue 'QueueInput/input_queue' for CPU prefetching ...


W0819 02:17:07.196660 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/tfutils/summary.py:237: The name tf.get_variable_scope is deprecated. Please use tf.compat.v1.get_variable_scope instead.

W0819 02:17:07.197694 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/tfutils/summary.py:27: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0819 02:17:07.230877 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/tfutils/summary.py:264: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.

W0819 02:17:07.244670 140735518651264 deprecation.py:323] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tgan/mo

[32m[0819 02:17:07 @registry.py:126][0m gen/LSTM/00/FC input: [50, 100]


W0819 02:17:07.814757 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/models/fc.py:57: The name tf.layers.Dense is deprecated. Please use tf.compat.v1.layers.Dense instead.



[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/00/FC output: [50, 100]
[32m[0819 02:17:08 @registry.py:126][0m gen/LSTM/00/FC2 input: [50, 100]
[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/00/FC2 output: [50, 2]
[32m[0819 02:17:08 @registry.py:126][0m gen/LSTM/00/FC3 input: [50, 2]
[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/00/FC3 output: [50, 100]
[32m[0819 02:17:08 @registry.py:126][0m gen/LSTM/01/FC input: [50, 100]
[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/01/FC output: [50, 100]
[32m[0819 02:17:08 @registry.py:126][0m gen/LSTM/01/FC2 input: [50, 100]
[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/01/FC2 output: [50, 3]
[32m[0819 02:17:08 @registry.py:126][0m gen/LSTM/01/FC3 input: [50, 3]
[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/01/FC3 output: [50, 100]
[32m[0819 02:17:08 @registry.py:126][0m gen/LSTM/02/FC input: [50, 100]
[32m[0819 02:17:08 @registry.py:134][0m gen/LSTM/02/FC output: [50, 100]
[32m[0819 02:17:08 @registry.p

W0819 02:17:09.521713 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/models/batch_norm.py:204: The name tf.layers.BatchNormalization is deprecated. Please use tf.compat.v1.layers.BatchNormalization instead.



[32m[0819 02:17:09 @registry.py:126][0m discrim/dis_fc_top input: [50, 110]
[32m[0819 02:17:09 @registry.py:134][0m discrim/dis_fc_top output: [50, 1]


W0819 02:17:09.718101 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tgan/model.py:118: The name tf.summary.histogram is deprecated. Please use tf.compat.v1.summary.histogram instead.

W0819 02:17:09.726330 140735518651264 deprecation.py:323] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0819 02:17:10.307646 140735518651264 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensor

[32m[0819 02:17:13 @logger.py:90][0m Argv: /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/ipykernel_launcher.py -f /Users/shotaroishihara/Library/Jupyter/runtime/kernel-0d4da767-c599-45d8-8a3d-021d24cc11d8.json


W0819 02:17:13.281690 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/saver.py:43: The name tf.gfile.IsDirectory is deprecated. Please use tf.io.gfile.isdir instead.

W0819 02:17:13.282520 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/saver.py:44: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.



[32m[0819 02:17:13 @model_utils.py:67][0m [36mList of Trainable Variables: 
[0mname                              shape         #elements
--------------------------------  ----------  -----------
gen/LSTM/go:0                     [1, 100]            100
gen/LSTM/lstm_cell/kernel:0       [500, 400]       200000
gen/LSTM/lstm_cell/bias:0         [400]               400
gen/LSTM/00/FC/W:0                [100, 100]        10000
gen/LSTM/00/FC/b:0                [100]               100
gen/LSTM/00/FC2/W:0               [100, 2]            200
gen/LSTM/00/FC2/b:0               [2]                   2
gen/LSTM/00/FC3/W:0               [2, 100]            200
gen/LSTM/00/FC3/b:0               [100]               100
gen/LSTM/00/attw:0                [1, 1, 1]             1
gen/LSTM/01/FC/W:0                [100, 100]        10000
gen/LSTM/01/FC/b:0                [100]               100
gen/LSTM/01/FC2/W:0               [100, 3]            300
gen/LSTM/01/FC2/b:0               [3]          

W0819 02:17:13.295694 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/graph.py:54: The name tf.train.SessionRunArgs is deprecated. Please use tf.estimator.SessionRunArgs instead.



[32m[0819 02:17:13 @summary.py:46][0m [MovingAverageSummary] 6 operations in collection 'MOVING_SUMMARY_OPS' will be run with session hooks.
[32m[0819 02:17:13 @summary.py:93][0m Summarizing collection 'summaries' of size 9.


W0819 02:17:13.492088 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/summary.py:94: The name tf.summary.merge_all is deprecated. Please use tf.compat.v1.summary.merge_all instead.



[32m[0819 02:17:13 @graph.py:98][0m Applying collection UPDATE_OPS of 4 ops.


W0819 02:17:13.498045 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/monitor.py:261: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.



[32m[0819 02:17:14 @base.py:230][0m Creating the session ...
[32m[0819 02:17:15 @base.py:236][0m Initializing the session ...
[32m[0819 02:17:15 @base.py:243][0m Graph Finalized.
[32m[0819 02:17:15 @concurrency.py:38][0m Starting EnqueueThread QueueInput/input_queue ...


W0819 02:17:15.522864 140735518651264 deprecation_wrapper.py:119] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorpack/callbacks/monitor.py:309: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.



[32m[0819 02:17:15 @base.py:275][0m Start Epoch 1 ...


100%|##########|10000/10000[02:35<00:00,64.41it/s]

[32m[0819 02:19:50 @base.py:285][0m Epoch 1 (global_step 10000) finished, time:2 minutes 35 seconds.
[32m[0819 02:19:50 @saver.py:79][0m Model saved to output/model/model-10000.
[32m[0819 02:19:50 @monitor.py:467][0m GAN_loss/discrim/accuracy_fake: 0.78
[32m[0819 02:19:50 @monitor.py:467][0m GAN_loss/discrim/accuracy_real: 0.3
[32m[0819 02:19:50 @monitor.py:467][0m GAN_loss/discrim/loss: 0.67256
[32m[0819 02:19:50 @monitor.py:467][0m GAN_loss/gen/final-g-loss: 1.0307
[32m[0819 02:19:50 @monitor.py:467][0m GAN_loss/gen/klloss: 0.16364
[32m[0819 02:19:50 @monitor.py:467][0m GAN_loss/gen/loss: 0.86706
[32m[0819 02:19:50 @monitor.py:467][0m QueueInput/queue_size: 50
[32m[0819 02:19:50 @base.py:275][0m Start Epoch 2 ...



100%|##########|10000/10000[02:31<00:00,65.98it/s]

[32m[0819 02:22:22 @base.py:285][0m Epoch 2 (global_step 20000) finished, time:2 minutes 31 seconds.
[32m[0819 02:22:22 @saver.py:79][0m Model saved to output/model/model-20000.
[32m[0819 02:22:22 @monitor.py:467][0m GAN_loss/discrim/accuracy_fake: 0.92
[32m[0819 02:22:22 @monitor.py:467][0m GAN_loss/discrim/accuracy_real: 0.3
[32m[0819 02:22:22 @monitor.py:467][0m GAN_loss/discrim/loss: 0.61909
[32m[0819 02:22:22 @monitor.py:467][0m GAN_loss/gen/final-g-loss: 1.1143
[32m[0819 02:22:22 @monitor.py:467][0m GAN_loss/gen/klloss: 0.12303
[32m[0819 02:22:22 @monitor.py:467][0m GAN_loss/gen/loss: 0.99129
[32m[0819 02:22:22 @monitor.py:467][0m QueueInput/queue_size: 50
[32m[0819 02:22:22 @base.py:275][0m Start Epoch 3 ...



100%|##########|10000/10000[02:34<00:00,64.53it/s]

[32m[0819 02:24:57 @base.py:285][0m Epoch 3 (global_step 30000) finished, time:2 minutes 34 seconds.
[32m[0819 02:24:57 @saver.py:79][0m Model saved to output/model/model-30000.
[32m[0819 02:24:57 @monitor.py:467][0m GAN_loss/discrim/accuracy_fake: 0.96
[32m[0819 02:24:57 @monitor.py:467][0m GAN_loss/discrim/accuracy_real: 0.2
[32m[0819 02:24:57 @monitor.py:467][0m GAN_loss/discrim/loss: 0.62615
[32m[0819 02:24:57 @monitor.py:467][0m GAN_loss/gen/final-g-loss: 1.1615
[32m[0819 02:24:57 @monitor.py:467][0m GAN_loss/gen/klloss: 0.16487
[32m[0819 02:24:57 @monitor.py:467][0m GAN_loss/gen/loss: 0.99665
[32m[0819 02:24:57 @monitor.py:467][0m QueueInput/queue_size: 50
[32m[0819 02:24:57 @base.py:275][0m Start Epoch 4 ...



100%|##########|10000/10000[02:32<00:00,65.74it/s]

[32m[0819 02:27:29 @base.py:285][0m Epoch 4 (global_step 40000) finished, time:2 minutes 32 seconds.
[32m[0819 02:27:29 @saver.py:79][0m Model saved to output/model/model-40000.
[32m[0819 02:27:29 @monitor.py:467][0m GAN_loss/discrim/accuracy_fake: 0.86
[32m[0819 02:27:29 @monitor.py:467][0m GAN_loss/discrim/accuracy_real: 0.34
[32m[0819 02:27:29 @monitor.py:467][0m GAN_loss/discrim/loss: 0.64826
[32m[0819 02:27:29 @monitor.py:467][0m GAN_loss/gen/final-g-loss: 1.0523
[32m[0819 02:27:29 @monitor.py:467][0m GAN_loss/gen/klloss: 0.12617
[32m[0819 02:27:29 @monitor.py:467][0m GAN_loss/gen/loss: 0.92613
[32m[0819 02:27:29 @monitor.py:467][0m QueueInput/queue_size: 50
[32m[0819 02:27:29 @base.py:275][0m Start Epoch 5 ...



100%|##########|10000/10000[02:27<00:00,67.99it/s]

[32m[0819 02:29:56 @base.py:285][0m Epoch 5 (global_step 50000) finished, time:2 minutes 27 seconds.
[32m[0819 02:29:56 @saver.py:79][0m Model saved to output/model/model-50000.
[32m[0819 02:29:56 @monitor.py:467][0m GAN_loss/discrim/accuracy_fake: 0.9
[32m[0819 02:29:56 @monitor.py:467][0m GAN_loss/discrim/accuracy_real: 0.24
[32m[0819 02:29:56 @monitor.py:467][0m GAN_loss/discrim/loss: 0.63292
[32m[0819 02:29:56 @monitor.py:467][0m GAN_loss/gen/final-g-loss: 1.3043
[32m[0819 02:29:56 @monitor.py:467][0m GAN_loss/gen/klloss: 0.2792
[32m[0819 02:29:56 @monitor.py:467][0m GAN_loss/gen/loss: 1.0251
[32m[0819 02:29:56 @monitor.py:467][0m QueueInput/queue_size: 50
[32m[0819 02:29:56 @base.py:289][0m Training has finished!





[32m[0819 02:29:57 @input_source.py:178][0m EnqueueThread QueueInput/input_queue Exited.
[32m[0819 02:29:58 @collection.py:146][0m New collections created in tower : tf.GraphKeys.REGULARIZATION_LOSSES
[32m[0819 02:29:58 @collection.py:165][0m These collections were modified but restored in : (tf.GraphKeys.SUMMARIES: 0->2)
[32m[0819 02:29:58 @sessinit.py:87][0m [5m[31mWRN[0m The following variables are in the checkpoint, but not found in the graph: global_step, optimize/beta1_power, optimize/beta2_power
[32m[0819 02:29:58 @sessinit.py:114][0m Restoring checkpoint from output/model/model-50000 ...


W0819 02:29:58.625829 140735518651264 deprecation.py:323] From /Users/shotaroishihara/.pyenv/versions/anaconda3-5.0.0/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


In [16]:
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7
0,0,3,male,22.0,1,0,7.25,S
1,1,1,female,38.0,1,0,71.2833,C
2,1,3,female,26.0,0,0,7.925,S
3,1,1,female,35.0,1,0,53.1,S
4,0,3,male,35.0,0,0,8.05,S


In [17]:
num_samples = len(df)
samples = tgan.sample(num_samples)

 32%|###2      |16/50[00:00<00:00,68.75it/s]


In [18]:
samples.columns = df_columns
samples.head()

Unnamed: 0,Survived,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,0,2,female,13.196098,3,0,20.713007,S
1,0,3,male,29.998602,0,0,11.851699,S
2,0,2,male,25.893719,0,0,10.928357,S
3,0,3,male,14.37834,0,0,9.984689,S
4,0,2,male,25.540975,0,0,13.816646,S


In [19]:
samples['Survived'].value_counts()

0    540
1    310
Name: Survived, dtype: int64

In [20]:
model_path = 'output/models/mymodel.pkl'
tgan.save(model_path)

[32m[0819 02:49:01 @model.py:813][0m Model saved successfully.
